@@ -23,7 +23,18 @@ def test_frame_error():
2323
2424
2525def test_framebatch_error ():
26- with pytest .raises (ValueError , match = "data must be at least 4-dimensional" ):
26+ with pytest .raises (ValueError , match = "data must be at least 3-dimensional" ):
27+ FrameBatch (
28+ data = torch .rand (2 , 3 ),
29+ pts_seconds = torch .rand (1 ),
30+ duration_seconds = torch .rand (1 ),
31+ )
32+
33+ # Note: this is expected to fail because pts_seconds and duration_seconds
34+ # are expected to have a shape of size([]) instead of size([1]).
35+ with pytest .raises (
36+ ValueError , match = "leading dimensions of the inputs do not match"
37+ ):
2738 FrameBatch (
2839 data = torch .rand (1 , 2 , 3 ),
2940 pts_seconds = torch .rand (1 ),
@@ -82,10 +93,14 @@ def test_framebatch_iteration():
8293 assert sub_fb .pts_seconds .shape == (N ,)
8394 assert sub_fb .duration_seconds .shape == (N ,)
8495 for frame in sub_fb :
85- assert isinstance (frame , Frame )
96+ assert isinstance (frame , FrameBatch )
8697 assert frame .data .shape == (C , H , W )
87- assert isinstance (frame .pts_seconds , float )
88- assert isinstance (frame .duration_seconds , float )
98+ # pts_seconds and duration_seconds are 0-dim tensors but they still
99+ # contain a value
100+ assert frame .pts_seconds .shape == tuple ()
101+ assert frame .duration_seconds .shape == tuple ()
102+ frame .pts_seconds .item ()
103+ frame .duration_seconds .item ()
89104
90105 # Check unpacking behavior
91106 first_sub_fb , * _ = fb
@@ -107,10 +122,15 @@ def test_framebatch_indexing():
107122 assert fb [i ].pts_seconds .shape == (N ,)
108123 assert fb [i ].duration_seconds .shape == (N ,)
109124 for j in range (len (fb [i ])):
110- assert isinstance (fb [i ][j ], Frame )
111- assert fb [i ][j ].data .shape == (C , H , W )
112- assert isinstance (fb [i ][j ].pts_seconds , float )
113- assert isinstance (fb [i ][j ].duration_seconds , float )
125+ frame = fb [i ][j ]
126+ assert isinstance (frame , FrameBatch )
127+ assert frame .data .shape == (C , H , W )
128+ # pts_seconds and duration_seconds are 0-dim tensors but they still
129+ # contain a value
130+ assert frame .pts_seconds .shape == tuple ()
131+ assert frame .duration_seconds .shape == tuple ()
132+ frame .pts_seconds .item ()
133+ frame .duration_seconds .item ()
114134
115135 fb_fancy = fb [torch .arange (3 )]
116136 assert isinstance (fb_fancy , FrameBatch )
0 commit comments