Skip to content

Commit 97ca885

Browse files
authored
Expose new get_frames_at and get_frames_displayed_at methods (#303)
1 parent bc89ce1 commit 97ca885

File tree

6 files changed

+165
-37
lines changed

6 files changed

+165
-37
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,32 @@ def get_frame_at(self, index: int) -> Frame:
181181
duration_seconds=duration_seconds.item(),
182182
)
183183

184+
def get_frames_at(self, indices: list[int]) -> FrameBatch:
185+
"""Return frames at the given indices.
186+
187+
.. note::
188+
189+
Calling this method is more efficient that repeated individual calls
190+
to :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at`. This
191+
method makes sure not to decode the same frame twice, and also
192+
avoids "backwards seek" operations, which are slow.
193+
194+
Args:
195+
indices (list of int): The indices of the frames to retrieve.
196+
197+
Returns:
198+
FrameBatch: The frames at the given indices.
199+
"""
200+
201+
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
202+
self._decoder, stream_index=self.stream_index, frame_indices=indices
203+
)
204+
return FrameBatch(
205+
data=data,
206+
pts_seconds=pts_seconds,
207+
duration_seconds=duration_seconds,
208+
)
209+
184210
def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatch:
185211
"""Return multiple frames at the given index range.
186212
@@ -238,6 +264,31 @@ def get_frame_displayed_at(self, seconds: float) -> Frame:
238264
duration_seconds=duration_seconds.item(),
239265
)
240266

267+
def get_frames_displayed_at(self, seconds: list[float]) -> FrameBatch:
268+
"""Return frames displayed at the given timestamps in seconds.
269+
270+
.. note::
271+
272+
Calling this method is more efficient that repeated individual calls
273+
to :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at`.
274+
This method makes sure not to decode the same frame twice, and also
275+
avoids "backwards seek" operations, which are slow.
276+
277+
Args:
278+
seconds (list of float): The timestamps in seconds when the frames are displayed.
279+
280+
Returns:
281+
FrameBatch: The frames that are displayed at ``seconds``.
282+
"""
283+
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
284+
self._decoder, timestamps=seconds, stream_index=self.stream_index
285+
)
286+
return FrameBatch(
287+
data=data,
288+
pts_seconds=pts_seconds,
289+
duration_seconds=duration_seconds,
290+
)
291+
241292
def get_frames_displayed_in_range(
242293
self, start_seconds: float, stop_seconds: float
243294
) -> FrameBatch:

src/torchcodec/samplers/_common.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Callable, Union
22

3-
from torch import Tensor
43
from torchcodec import FrameBatch
54

65
_LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]]
@@ -58,17 +57,15 @@ def _validate_common_params(*, decoder, num_frames_per_clip, policy):
5857
)
5958

6059

61-
def _make_5d_framebatch(
60+
def _reshape_4d_framebatch_into_5d(
6261
*,
63-
data: Tensor,
64-
pts_seconds: Tensor,
65-
duration_seconds: Tensor,
62+
frames: FrameBatch,
6663
num_clips: int,
6764
num_frames_per_clip: int,
6865
) -> FrameBatch:
69-
last_3_dims = data.shape[-3:]
66+
last_3_dims = frames.data.shape[-3:]
7067
return FrameBatch(
71-
data=data.view(num_clips, num_frames_per_clip, *last_3_dims),
72-
pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
73-
duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
68+
data=frames.data.view(num_clips, num_frames_per_clip, *last_3_dims),
69+
pts_seconds=frames.pts_seconds.view(num_clips, num_frames_per_clip),
70+
duration_seconds=frames.duration_seconds.view(num_clips, num_frames_per_clip),
7471
)

src/torchcodec/samplers/_index_based.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
from torchcodec import FrameBatch
66
from torchcodec.decoders import VideoDecoder
7-
from torchcodec.decoders._core import get_frames_at_indices
87
from torchcodec.samplers._common import (
9-
_make_5d_framebatch,
108
_POLICY_FUNCTION_TYPE,
119
_POLICY_FUNCTIONS,
10+
_reshape_4d_framebatch_into_5d,
1211
_validate_common_params,
1312
)
1413

@@ -177,16 +176,9 @@ def _generic_index_based_sampler(
177176
policy_fun=_POLICY_FUNCTIONS[policy],
178177
)
179178

180-
# TODO: Use public method of decoder, when it exists
181-
frames, pts_seconds, duration_seconds = get_frames_at_indices(
182-
decoder._decoder,
183-
stream_index=decoder.stream_index,
184-
frame_indices=all_clips_indices,
185-
)
186-
return _make_5d_framebatch(
187-
data=frames,
188-
pts_seconds=pts_seconds,
189-
duration_seconds=duration_seconds,
179+
frames = decoder.get_frames_at(indices=all_clips_indices)
180+
return _reshape_4d_framebatch_into_5d(
181+
frames=frames,
190182
num_clips=num_clips,
191183
num_frames_per_clip=num_frames_per_clip,
192184
)

src/torchcodec/samplers/_time_based.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
import torch
44

55
from torchcodec import FrameBatch
6-
from torchcodec.decoders._core import get_frames_by_pts
76
from torchcodec.samplers._common import (
8-
_make_5d_framebatch,
97
_POLICY_FUNCTION_TYPE,
108
_POLICY_FUNCTIONS,
9+
_reshape_4d_framebatch_into_5d,
1110
_validate_common_params,
1211
)
1312

@@ -210,16 +209,9 @@ def _generic_time_based_sampler(
210209
policy_fun=_POLICY_FUNCTIONS[policy],
211210
)
212211

213-
# TODO: Use public method of decoder, when it exists
214-
frames, pts_seconds, duration_seconds = get_frames_by_pts(
215-
decoder._decoder,
216-
stream_index=decoder.stream_index,
217-
timestamps=all_clips_timestamps,
218-
)
219-
return _make_5d_framebatch(
220-
data=frames,
221-
pts_seconds=pts_seconds,
222-
duration_seconds=duration_seconds,
212+
frames = decoder.get_frames_displayed_at(seconds=all_clips_timestamps)
213+
return _reshape_4d_framebatch_into_5d(
214+
frames=frames,
223215
num_clips=num_clips,
224216
num_frames_per_clip=num_frames_per_clip,
225217
)

test/decoders/test_video_decoder.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy
88
import pytest
99
import torch
10+
from torchcodec import FrameBatch
1011

1112
from torchcodec.decoders import _core, VideoDecoder
1213

@@ -301,9 +302,12 @@ def test_get_frame_at(self):
301302

302303
assert_tensor_equal(ref_frame9, frame9.data)
303304
assert isinstance(frame9.pts_seconds, float)
304-
assert frame9.pts_seconds == pytest.approx(0.3003)
305+
expected_frame_info = NASA_VIDEO.get_frame_info(9)
306+
assert frame9.pts_seconds == pytest.approx(expected_frame_info.pts_seconds)
305307
assert isinstance(frame9.duration_seconds, float)
306-
assert frame9.duration_seconds == pytest.approx(0.03337, rel=1e-3)
308+
assert frame9.duration_seconds == pytest.approx(
309+
expected_frame_info.duration_seconds, rel=1e-3
310+
)
307311

308312
# test numpy.int64
309313
frame9 = decoder.get_frame_at(numpy.int64(9))
@@ -340,6 +344,50 @@ def test_get_frame_at_fails(self):
340344
with pytest.raises(IndexError, match="out of bounds"):
341345
frame = decoder.get_frame_at(10000) # noqa
342346

347+
def test_get_frames_at(self):
348+
decoder = VideoDecoder(NASA_VIDEO.path)
349+
350+
frames = decoder.get_frames_at([35, 25])
351+
352+
assert isinstance(frames, FrameBatch)
353+
354+
assert_tensor_equal(frames[0].data, NASA_VIDEO.get_frame_data_by_index(35))
355+
assert_tensor_equal(frames[1].data, NASA_VIDEO.get_frame_data_by_index(25))
356+
357+
expected_pts_seconds = torch.tensor(
358+
[
359+
NASA_VIDEO.get_frame_info(35).pts_seconds,
360+
NASA_VIDEO.get_frame_info(25).pts_seconds,
361+
],
362+
dtype=torch.float64,
363+
)
364+
torch.testing.assert_close(
365+
frames.pts_seconds, expected_pts_seconds, atol=1e-4, rtol=0
366+
)
367+
368+
expected_duration_seconds = torch.tensor(
369+
[
370+
NASA_VIDEO.get_frame_info(35).duration_seconds,
371+
NASA_VIDEO.get_frame_info(25).duration_seconds,
372+
],
373+
dtype=torch.float64,
374+
)
375+
torch.testing.assert_close(
376+
frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0
377+
)
378+
379+
def test_get_frames_at_fails(self):
380+
decoder = VideoDecoder(NASA_VIDEO.path)
381+
382+
with pytest.raises(RuntimeError, match="Invalid frame index=-1"):
383+
decoder.get_frames_at([-1])
384+
385+
with pytest.raises(RuntimeError, match="Invalid frame index=390"):
386+
decoder.get_frames_at([390])
387+
388+
with pytest.raises(RuntimeError, match="Expected a value of type"):
389+
decoder.get_frames_at([0.3])
390+
343391
def test_get_frame_displayed_at(self):
344392
decoder = VideoDecoder(NASA_VIDEO.path)
345393

@@ -365,6 +413,51 @@ def test_get_frame_displayed_at_fails(self):
365413
with pytest.raises(IndexError, match="Invalid pts in seconds"):
366414
frame = decoder.get_frame_displayed_at(100.0) # noqa
367415

416+
def test_get_frames_displayed_at(self):
417+
418+
decoder = VideoDecoder(NASA_VIDEO.path)
419+
420+
# Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
421+
# index 35. We use those indices as reference to test against.
422+
seconds = [0.84, 1.17, 0.85]
423+
reference_indices = [25, 35, 25]
424+
frames = decoder.get_frames_displayed_at(seconds)
425+
426+
assert isinstance(frames, FrameBatch)
427+
428+
for i in range(len(reference_indices)):
429+
assert_tensor_equal(
430+
frames.data[i], NASA_VIDEO.get_frame_data_by_index(reference_indices[i])
431+
)
432+
433+
expected_pts_seconds = torch.tensor(
434+
[NASA_VIDEO.get_frame_info(i).pts_seconds for i in reference_indices],
435+
dtype=torch.float64,
436+
)
437+
torch.testing.assert_close(
438+
frames.pts_seconds, expected_pts_seconds, atol=1e-4, rtol=0
439+
)
440+
441+
expected_duration_seconds = torch.tensor(
442+
[NASA_VIDEO.get_frame_info(i).duration_seconds for i in reference_indices],
443+
dtype=torch.float64,
444+
)
445+
torch.testing.assert_close(
446+
frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0
447+
)
448+
449+
def test_get_frames_displayed_at_fails(self):
450+
decoder = VideoDecoder(NASA_VIDEO.path)
451+
452+
with pytest.raises(RuntimeError, match="must be in range"):
453+
decoder.get_frames_displayed_at([-1])
454+
455+
with pytest.raises(RuntimeError, match="must be in range"):
456+
decoder.get_frames_displayed_at([14])
457+
458+
with pytest.raises(RuntimeError, match="Expected a value of type"):
459+
decoder.get_frames_displayed_at(["bad"])
460+
368461
@pytest.mark.parametrize("stream_index", [0, 3, None])
369462
def test_get_frames_in_range(self, stream_index):
370463
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)
@@ -456,10 +549,11 @@ def test_get_frames_in_range(self, stream_index):
456549
(
457550
lambda decoder: decoder[0],
458551
lambda decoder: decoder.get_frame_at(0).data,
552+
lambda decoder: decoder.get_frames_at([0, 1]).data,
459553
lambda decoder: decoder.get_frames_in_range(0, 4).data,
460554
lambda decoder: decoder.get_frame_displayed_at(0).data,
461-
# TODO: uncomment once D60001893 lands
462-
# lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data,
555+
lambda decoder: decoder.get_frames_displayed_at([0, 1]).data,
556+
lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data,
463557
),
464558
)
465559
def test_dimension_order(self, dimension_order, frame_getter):

test/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor:
265265
8: TestFrameInfo(pts_seconds=0.266933, duration_seconds=0.033367),
266266
9: TestFrameInfo(pts_seconds=0.300300, duration_seconds=0.033367),
267267
10: TestFrameInfo(pts_seconds=0.333667, duration_seconds=0.033367),
268+
25: TestFrameInfo(pts_seconds=0.8342, duration_seconds=0.033367),
269+
35: TestFrameInfo(pts_seconds=1.1678, duration_seconds=0.033367),
268270
},
269271
},
270272
)

0 commit comments

Comments
 (0)