Skip to content

Commit df7a464

Browse files
committed
Indexing 4D FrameBatch now returns FrameBatch
1 parent 9d7b240 commit df7a464

File tree

2 files changed

+38
-30
lines changed

2 files changed

+38
-30
lines changed

src/torchcodec/_frame.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,9 @@ class FrameBatch(Iterable):
6868
def __post_init__(self):
6969
# This is called after __init__() when a FrameBatch is created. We can
7070
# run input validation checks here.
71-
if self.data.ndim < 4:
71+
if self.data.ndim < 3:
7272
raise ValueError(
73-
f"data must be at least 4-dimensional. Got {self.data.shape = } "
74-
"For 3-dimensional data, create a Frame object instead."
73+
f"data must be at least 3-dimensional, got {self.data.shape = }"
7574
)
7675

7776
leading_dims = self.data.shape[:-3]
@@ -83,33 +82,22 @@ def __post_init__(self):
8382
f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }."
8483
)
8584

86-
def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]:
87-
cls = Frame if self.data.ndim == 4 else FrameBatch
85+
def __iter__(self) -> Iterator["FrameBatch"]:
8886
for data, pts_seconds, duration_seconds in zip(
8987
self.data, self.pts_seconds, self.duration_seconds
9088
):
91-
yield cls(
89+
yield FrameBatch(
9290
data=data,
9391
pts_seconds=pts_seconds,
9492
duration_seconds=duration_seconds,
9593
)
9694

97-
def __getitem__(self, key) -> Union["FrameBatch", Frame]:
98-
data = self.data[key]
99-
pts_seconds = self.pts_seconds[key]
100-
duration_seconds = self.duration_seconds[key]
101-
if self.data.ndim == 4:
102-
return Frame(
103-
data=data,
104-
pts_seconds=float(pts_seconds.item()),
105-
duration_seconds=float(duration_seconds.item()),
106-
)
107-
else:
108-
return FrameBatch(
109-
data=data,
110-
pts_seconds=pts_seconds,
111-
duration_seconds=duration_seconds,
112-
)
95+
def __getitem__(self, key) -> "FrameBatch":
96+
return FrameBatch(
97+
data=self.data[key],
98+
pts_seconds=self.pts_seconds[key],
99+
duration_seconds=self.duration_seconds[key],
100+
)
113101

114102
def __len__(self):
115103
return len(self.data)

test/test_frame_dataclasses.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,18 @@ def test_frame_error():
2323

2424

2525
def test_framebatch_error():
26-
with pytest.raises(ValueError, match="data must be at least 4-dimensional"):
26+
with pytest.raises(ValueError, match="data must be at least 3-dimensional"):
27+
FrameBatch(
28+
data=torch.rand(2, 3),
29+
pts_seconds=torch.rand(1),
30+
duration_seconds=torch.rand(1),
31+
)
32+
33+
# Note: this is expected to fail because pts_seconds and duration_seconds
34+
# are expected to have a shape of size([]) instead of size([1]).
35+
with pytest.raises(
36+
ValueError, match="leading dimensions of the inputs do not match"
37+
):
2738
FrameBatch(
2839
data=torch.rand(1, 2, 3),
2940
pts_seconds=torch.rand(1),
@@ -82,10 +93,14 @@ def test_framebatch_iteration():
8293
assert sub_fb.pts_seconds.shape == (N,)
8394
assert sub_fb.duration_seconds.shape == (N,)
8495
for frame in sub_fb:
85-
assert isinstance(frame, Frame)
96+
assert isinstance(frame, FrameBatch)
8697
assert frame.data.shape == (C, H, W)
87-
assert isinstance(frame.pts_seconds, float)
88-
assert isinstance(frame.duration_seconds, float)
98+
# pts_seconds and duration_seconds are 0-dim tensors but they still
99+
# contain a value
100+
assert frame.pts_seconds.shape == tuple()
101+
assert frame.duration_seconds.shape == tuple()
102+
frame.pts_seconds.item()
103+
frame.duration_seconds.item()
89104

90105
# Check unpacking behavior
91106
first_sub_fb, *_ = fb
@@ -107,10 +122,15 @@ def test_framebatch_indexing():
107122
assert fb[i].pts_seconds.shape == (N,)
108123
assert fb[i].duration_seconds.shape == (N,)
109124
for j in range(len(fb[i])):
110-
assert isinstance(fb[i][j], Frame)
111-
assert fb[i][j].data.shape == (C, H, W)
112-
assert isinstance(fb[i][j].pts_seconds, float)
113-
assert isinstance(fb[i][j].duration_seconds, float)
125+
frame = fb[i][j]
126+
assert isinstance(frame, FrameBatch)
127+
assert frame.data.shape == (C, H, W)
128+
# pts_seconds and duration_seconds are 0-dim tensors but they still
129+
# contain a value
130+
assert frame.pts_seconds.shape == tuple()
131+
assert frame.duration_seconds.shape == tuple()
132+
frame.pts_seconds.item()
133+
frame.duration_seconds.item()
114134

115135
fb_fancy = fb[torch.arange(3)]
116136
assert isinstance(fb_fancy, FrameBatch)

0 commit comments

Comments
 (0)