Skip to content

Commit 1482529

Browse files
committed
Frame and FrameBatch improvements
1 parent c8de21c commit 1482529

File tree

2 files changed

+165
-3
lines changed

2 files changed

+165
-3
lines changed

src/torchcodec/_frame.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ class Frame(Iterable):
3838
duration_seconds: float
3939
"""The duration of the frame, in seconds (float)."""
4040

41+
def __post_init__(self):
42+
if not self.data.ndim == 3:
43+
raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
44+
45+
self.pts_seconds = float(self.pts_seconds)
46+
self.duration_seconds = float(self.duration_seconds)
47+
4148
def __iter__(self) -> Iterator[Union[Tensor, float]]:
4249
for field in dataclasses.fields(self):
4350
yield getattr(self, field.name)
@@ -57,9 +64,43 @@ class FrameBatch(Iterable):
5764
duration_seconds: Tensor
5865
"""The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
5966

60-
def __iter__(self) -> Iterator[Union[Tensor, float]]:
61-
for field in dataclasses.fields(self):
62-
yield getattr(self, field.name)
67+
def __post_init__(self):
68+
if self.data.ndim < 4:
69+
raise ValueError(
70+
f"data must be at least 4-dimensional. Got {self.data.shape = } "
71+
"For 3-dimensional data, create a Frame object instead."
72+
)
73+
74+
leading_dims = self.data.shape[:-3]
75+
if not (leading_dims == self.pts_seconds.shape == self.duration_seconds.shape):
76+
raise ValueError(
77+
"Tried to create a FrameBatch but the leading dimensions of the inputs do not match. "
78+
f"Got {self.data.shape = } so we expected the shape of pts_seconds and "
79+
f"duration_seconds to be {leading_dims = }, but got "
80+
f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }."
81+
)
82+
83+
def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]:
84+
cls = Frame if self.data.ndim == 4 else FrameBatch
85+
for data, pts_seconds, duration_seconds in zip(
86+
self.data, self.pts_seconds, self.duration_seconds
87+
):
88+
yield cls(
89+
data=data,
90+
pts_seconds=pts_seconds,
91+
duration_seconds=duration_seconds,
92+
)
93+
94+
def __getitem__(self, key) -> Union["FrameBatch", Frame]:
95+
cls = Frame if self.data.ndim == 4 else FrameBatch
96+
return cls(
97+
self.data[key],
98+
self.pts_seconds[key],
99+
self.duration_seconds[key],
100+
)
101+
102+
def __len__(self):
103+
return len(self.data)
63104

64105
def __repr__(self):
65106
return _frame_repr(self)

test/test_frame_dataclasses.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import pytest
2+
import torch
3+
from torchcodec import Frame, FrameBatch
4+
5+
6+
def test_frame_unpacking():
7+
data, pts_seconds, duration_seconds = Frame(torch.rand(3, 4, 5), 2, 3) # noqa
8+
9+
10+
def test_frame_error():
11+
with pytest.raises(ValueError, match="data must be 3-dimensional"):
12+
Frame(
13+
data=torch.rand(1, 2),
14+
pts_seconds=1,
15+
duration_seconds=1,
16+
)
17+
with pytest.raises(ValueError, match="data must be 3-dimensional"):
18+
Frame(
19+
data=torch.rand(1, 2, 3, 4),
20+
pts_seconds=1,
21+
duration_seconds=1,
22+
)
23+
24+
25+
def test_framebatch_error():
26+
with pytest.raises(ValueError, match="data must be at least 4-dimensional"):
27+
FrameBatch(
28+
data=torch.rand(1, 2, 3),
29+
pts_seconds=torch.rand(1),
30+
duration_seconds=torch.rand(1),
31+
)
32+
33+
with pytest.raises(
34+
ValueError, match="leading dimensions of the inputs do not match"
35+
):
36+
FrameBatch(
37+
data=torch.rand(3, 4, 2, 1),
38+
pts_seconds=torch.rand(3), # ok
39+
duration_seconds=torch.rand(2), # bad
40+
)
41+
42+
with pytest.raises(
43+
ValueError, match="leading dimensions of the inputs do not match"
44+
):
45+
FrameBatch(
46+
data=torch.rand(3, 4, 2, 1),
47+
pts_seconds=torch.rand(2), # bad
48+
duration_seconds=torch.rand(3), # ok
49+
)
50+
51+
with pytest.raises(
52+
ValueError, match="leading dimensions of the inputs do not match"
53+
):
54+
FrameBatch(
55+
data=torch.rand(5, 3, 4, 2, 1),
56+
pts_seconds=torch.rand(5, 3), # ok
57+
duration_seconds=torch.rand(5, 2), # bad
58+
)
59+
60+
with pytest.raises(
61+
ValueError, match="leading dimensions of the inputs do not match"
62+
):
63+
FrameBatch(
64+
data=torch.rand(5, 3, 4, 2, 1),
65+
pts_seconds=torch.rand(5, 2), # bad
66+
duration_seconds=torch.rand(5, 3), # ok
67+
)
68+
69+
70+
def test_framebatch_iteration():
71+
T, N, C, H, W = 7, 6, 3, 2, 4
72+
73+
fb = FrameBatch(
74+
data=torch.rand(T, N, C, H, W),
75+
pts_seconds=torch.rand(T, N),
76+
duration_seconds=torch.rand(T, N),
77+
)
78+
79+
for sub_fb in fb:
80+
assert isinstance(sub_fb, FrameBatch)
81+
assert sub_fb.data.shape == (N, C, H, W)
82+
assert sub_fb.pts_seconds.shape == (N,)
83+
assert sub_fb.duration_seconds.shape == (N,)
84+
for frame in sub_fb:
85+
assert isinstance(frame, Frame)
86+
assert frame.data.shape == (C, H, W)
87+
assert isinstance(frame.pts_seconds, float)
88+
assert isinstance(frame.duration_seconds, float)
89+
90+
# Check unpacking behavior
91+
first_sub_fb, *_ = fb
92+
assert isinstance(first_sub_fb, FrameBatch)
93+
94+
95+
def test_framebatch_indexing():
96+
T, N, C, H, W = 7, 6, 3, 2, 4
97+
98+
fb = FrameBatch(
99+
data=torch.rand(T, N, C, H, W),
100+
pts_seconds=torch.rand(T, N),
101+
duration_seconds=torch.rand(T, N),
102+
)
103+
104+
for i in range(len(fb)):
105+
assert isinstance(fb[i], FrameBatch)
106+
assert fb[i].data.shape == (N, C, H, W)
107+
assert fb[i].pts_seconds.shape == (N,)
108+
assert fb[i].duration_seconds.shape == (N,)
109+
for j in range(len(fb[i])):
110+
assert isinstance(fb[i][j], Frame)
111+
assert fb[i][j].data.shape == (C, H, W)
112+
assert isinstance(fb[i][j].pts_seconds, float)
113+
assert isinstance(fb[i][j].duration_seconds, float)
114+
115+
fb_fancy = fb[torch.arange(3)]
116+
assert isinstance(fb_fancy, FrameBatch)
117+
assert fb_fancy.data.shape == (3, N, C, H, W)
118+
119+
fb_fancy = fb[[[0], [1]]] # select T=0 and N=1.
120+
assert isinstance(fb_fancy, FrameBatch)
121+
assert fb_fancy.data.shape == (1, C, H, W)

0 commit comments

Comments
 (0)