File tree Expand file tree Collapse file tree 1 file changed +15
-7
lines changed
Expand file tree Collapse file tree 1 file changed +15
-7
lines changed Original file line number Diff line number Diff line change @@ -41,7 +41,6 @@ class Frame(Iterable):
4141 def __post_init__ (self ):
4242 if not self .data .ndim == 3 :
4343 raise ValueError (f"data must be 3-dimensional, got { self .data .shape = } " )
44-
4544 self .pts_seconds = float (self .pts_seconds )
4645 self .duration_seconds = float (self .duration_seconds )
4746
@@ -92,12 +91,21 @@ def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]:
9291 )
9392
9493 def __getitem__ (self , key ) -> Union ["FrameBatch" , Frame ]:
95- cls = Frame if self .data .ndim == 4 else FrameBatch
96- return cls (
97- self .data [key ],
98- self .pts_seconds [key ],
99- self .duration_seconds [key ],
100- )
94+ data = self .data [key ]
95+ pts_seconds = self .pts_seconds [key ]
96+ duration_seconds = self .duration_seconds [key ]
97+ if self .data .ndim == 4 :
98+ return Frame (
99+ data = data ,
100+ pts_seconds = float (pts_seconds .item ()),
101+ duration_seconds = float (duration_seconds .item ()),
102+ )
103+ else :
104+ return FrameBatch (
105+ data = data ,
106+ pts_seconds = pts_seconds ,
107+ duration_seconds = duration_seconds ,
108+ )
101109
102110 def __len__ (self ):
103111 return len (self .data )
You can’t perform that action at this time.
0 commit comments