Skip to content

Commit 1b08fac

Browse files
committed
Expose get_frames_displayed_at
1 parent 7d048df commit 1b08fac

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,24 @@ def get_frame_displayed_at(self, seconds: float) -> Frame:
264264
duration_seconds=duration_seconds.item(),
265265
)
266266

267+
def get_frames_displayed_at(self, seconds: list[float]) -> FrameBatch:
268+
"""Return frames displayed at the given timestamps in seconds.
269+
270+
Args:
271+
seconds (list of float): The timestamps in seconds when the frames are displayed.
272+
273+
Returns:
274+
FrameBatch: The frames that are displayed at ``seconds``.
275+
"""
276+
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
277+
self._decoder, timestamps=seconds, stream_index=self.stream_index
278+
)
279+
return FrameBatch(
280+
data=data,
281+
pts_seconds=pts_seconds,
282+
duration_seconds=duration_seconds,
283+
)
284+
267285
def get_frames_displayed_in_range(
268286
self, start_seconds: float, stop_seconds: float
269287
) -> FrameBatch:

test/decoders/test_video_decoder.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,10 +366,13 @@ def test_get_frames_at(self):
366366

367367
def test_get_frames_at_fails(self):
368368
decoder = VideoDecoder(NASA_VIDEO.path)
369+
369370
with pytest.raises(RuntimeError, match="Invalid frame index=-1"):
370371
decoder.get_frames_at([-1])
372+
371373
with pytest.raises(RuntimeError, match="Invalid frame index=390"):
372374
decoder.get_frames_at([390])
375+
373376
with pytest.raises(RuntimeError, match="Expected a value of type"):
374377
decoder.get_frames_at([0.3])
375378

@@ -398,6 +401,47 @@ def test_get_frame_displayed_at_fails(self):
398401
with pytest.raises(IndexError, match="Invalid pts in seconds"):
399402
frame = decoder.get_frame_displayed_at(100.0) # noqa
400403

404+
def test_get_frames_displayed_at(self):
405+
406+
decoder = VideoDecoder(NASA_VIDEO.path)
407+
ref_frame6 = NASA_VIDEO.get_frame_by_name("time6.000000")
408+
ref_frame10 = NASA_VIDEO.get_frame_by_name("time10.000000")
409+
410+
seconds = [6.02, 10.01, 6.01]
411+
frames = decoder.get_frames_displayed_at(seconds)
412+
413+
assert isinstance(frames, FrameBatch)
414+
415+
assert_tensor_equal(frames.data[0], ref_frame6)
416+
assert_tensor_equal(frames.data[1], ref_frame10)
417+
assert_tensor_equal(frames.data[2], ref_frame6)
418+
419+
expected_pts_seconds = torch.tensor(
420+
[6.0060, 10.0100, 6.0060], dtype=torch.float64
421+
)
422+
torch.testing.assert_close(
423+
frames.pts_seconds, expected_pts_seconds, atol=1e-4, rtol=0
424+
)
425+
426+
expected_duration_seconds = torch.tensor(
427+
[0.0334, 0.0334, 0.0334], dtype=torch.float64
428+
)
429+
torch.testing.assert_close(
430+
frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0
431+
)
432+
433+
def test_get_frames_displayed_at_fails(self):
434+
decoder = VideoDecoder(NASA_VIDEO.path)
435+
436+
with pytest.raises(RuntimeError, match="must be in range"):
437+
decoder.get_frames_displayed_at([-1])
438+
439+
with pytest.raises(RuntimeError, match="must be in range"):
440+
decoder.get_frames_displayed_at([14])
441+
442+
with pytest.raises(RuntimeError, match="Expected a value of type"):
443+
decoder.get_frames_displayed_at(["bad"])
444+
401445
@pytest.mark.parametrize("stream_index", [0, 3, None])
402446
def test_get_frames_in_range(self, stream_index):
403447
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)

0 commit comments

Comments
 (0)