Skip to content

Commit 4661237

Browse files
committed
Fix mypy?
1 parent 1482529 commit 4661237

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

src/torchcodec/_frame.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ class Frame(Iterable):
4141
def __post_init__(self):
4242
if not self.data.ndim == 3:
4343
raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
44-
4544
self.pts_seconds = float(self.pts_seconds)
4645
self.duration_seconds = float(self.duration_seconds)
4746

@@ -92,12 +91,21 @@ def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]:
9291
)
9392

9493
def __getitem__(self, key) -> Union["FrameBatch", Frame]:
95-
cls = Frame if self.data.ndim == 4 else FrameBatch
96-
return cls(
97-
self.data[key],
98-
self.pts_seconds[key],
99-
self.duration_seconds[key],
100-
)
94+
data = self.data[key]
95+
pts_seconds = self.pts_seconds[key]
96+
duration_seconds = self.duration_seconds[key]
97+
if self.data.ndim == 4:
98+
return Frame(
99+
data=data,
100+
pts_seconds=float(pts_seconds.item()),
101+
duration_seconds=float(duration_seconds.item()),
102+
)
103+
else:
104+
return FrameBatch(
105+
data=data,
106+
pts_seconds=pts_seconds,
107+
duration_seconds=duration_seconds,
108+
)
101109

102110
def __len__(self):
103111
return len(self.data)

0 commit comments

Comments
 (0)