Skip to content

Commit c2e202d

Browse files
authored
Let get_frames_at and get_frames_played_at accept tensor indices (#880)
1 parent e12e466 commit c2e202d

File tree

2 files changed

23 additions & 0 deletions

2 files changed

23 additions & 0 deletions

src/torchcodec/decoders/_video_decoder.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -247,6 +247,12 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
247247
Returns:
248248
FrameBatch: The frames at the given indices.
249249
"""
250+
if isinstance(indices, torch.Tensor):
251+
# TODO we should avoid converting tensors to lists and just let the
252+
# core ops and C++ code natively accept tensors. See
253+
# https://github.com/pytorch/torchcodec/issues/879
254+
indices = indices.to(torch.int).tolist()
255+
250256
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
251257
self._decoder, frame_indices=indices
252258
)
@@ -322,6 +328,12 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
322328
Returns:
323329
FrameBatch: The frames that are played at ``seconds``.
324330
"""
331+
if isinstance(seconds, torch.Tensor):
332+
# TODO we should avoid converting tensors to lists and just let the
333+
# core ops and C++ code natively accept tensors. See
334+
# https://github.com/pytorch/torchcodec/issues/879
335+
seconds = seconds.to(torch.float).tolist()
336+
325337
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
326338
self._decoder, timestamps=seconds
327339
)

test/test_decoders.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1390,6 +1390,17 @@ def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device):
13901390
custom_frame_mappings=custom_frame_mappings,
13911391
)
13921392

1393+
def test_get_frames_at_tensor_indices(self):
1394+
# Non-regression test for tensor support in get_frames_at() and
1395+
# get_frames_played_at()
1396+
decoder = VideoDecoder(NASA_VIDEO.path)
1397+
1398+
decoder.get_frames_at(torch.tensor([0, 10], dtype=torch.int))
1399+
decoder.get_frames_at(torch.tensor([0, 10], dtype=torch.float))
1400+
1401+
decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.int))
1402+
decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.float))
1403+
13931404

13941405
class TestAudioDecoder:
13951406
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))

0 commit comments

Comments
 (0)