Skip to content

Commit 9fa4fd1

Browse files
authored
Added proper tensor support for get_frames_played_at() (#932)
Co-authored-by: Molly Xu <[email protected]> Resolved issue #879 #879
1 parent 6377dfc commit 9fa4fd1

File tree

6 files changed

+45
-24
lines changed

6 files changed

+45
-24
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
764764
}
765765

766766
FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
767-
const std::vector<double>& timestamps) {
767+
const torch::Tensor& timestamps) {
768768
validateActiveStream(AVMEDIA_TYPE_VIDEO);
769769

770770
const auto& streamMetadata =
@@ -778,9 +778,13 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
778778
// avoid decoding that unique frame twice is to convert the input timestamps
779779
// to indices, and leverage the de-duplication logic of getFramesAtIndices.
780780

781-
std::vector<int64_t> frameIndices(timestamps.size());
782-
for (size_t i = 0; i < timestamps.size(); ++i) {
783-
auto frameSeconds = timestamps[i];
781+
torch::Tensor frameIndices =
782+
torch::empty({timestamps.numel()}, torch::kInt64);
783+
auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();
784+
auto timestampsAccessor = timestamps.accessor<double, 1>();
785+
786+
for (int64_t i = 0; i < timestamps.numel(); ++i) {
787+
auto frameSeconds = timestampsAccessor[i];
784788
TORCH_CHECK(
785789
frameSeconds >= minSeconds,
786790
"frame pts is " + std::to_string(frameSeconds) +
@@ -797,11 +801,10 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
797801
".");
798802
}
799803

800-
frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
804+
frameIndicesAccessor[i] = secondsToIndexLowerBound(frameSeconds);
801805
}
802806

803-
// TODO: Support tensors natively instead of a vector to avoid a copy.
804-
return getFramesAtIndices(torch::tensor(frameIndices));
807+
return getFramesAtIndices(frameIndices);
805808
}
806809

807810
FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class SingleStreamDecoder {
123123
// seconds=5.999, etc.
124124
FrameOutput getFramePlayedAt(double seconds);
125125

126-
FrameBatchOutput getFramesPlayedAt(const std::vector<double>& timestamps);
126+
FrameBatchOutput getFramesPlayedAt(const torch::Tensor& timestamps);
127127

128128
// Returns frames within a given pts range. The range is defined by
129129
// [startSeconds, stopSeconds) with respect to the pts values for frames. The

src/torchcodec/_core/custom_ops.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
6363
m.def(
6464
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)");
6565
m.def(
66-
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
66+
"get_frames_by_pts(Tensor(a!) decoder, *, Tensor timestamps) -> (Tensor, Tensor, Tensor)");
6767
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
6868
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
6969
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
@@ -423,10 +423,9 @@ OpsFrameBatchOutput get_frames_in_range(
423423
// Return the frames at given ptss for a given stream
424424
OpsFrameBatchOutput get_frames_by_pts(
425425
at::Tensor& decoder,
426-
at::ArrayRef<double> timestamps) {
426+
const at::Tensor& timestamps) {
427427
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
428-
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
429-
auto result = videoDecoder->getFramesPlayedAt(timestampsVec);
428+
auto result = videoDecoder->getFramesPlayedAt(timestamps);
430429
return makeOpsFrameBatchOutput(result);
431430
}
432431

src/torchcodec/_core/ops.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def load_torchcodec_shared_libraries():
117117
_get_frames_at_indices_tensor_input = (
118118
torch.ops.torchcodec_ns.get_frames_at_indices.default
119119
)
120-
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
120+
_get_frames_by_pts_tensor_input = torch.ops.torchcodec_ns.get_frames_by_pts.default
121121
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
122122
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
123123
get_frames_by_pts_in_range_audio = (
@@ -212,6 +212,21 @@ def get_frames_at_indices(
212212
return _get_frames_at_indices_tensor_input(decoder, frame_indices=frame_indices)
213213

214214

215+
def get_frames_by_pts(
216+
decoder: torch.Tensor, *, timestamps: Union[torch.Tensor, list[float]]
217+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
218+
if isinstance(timestamps, torch.Tensor):
219+
# Ensure indices is the correct dtype (float64)
220+
timestamps = timestamps.to(torch.float64)
221+
else:
222+
# Convert list to tensor for dispatch
223+
try:
224+
timestamps = torch.tensor(timestamps, dtype=torch.float64)
225+
except Exception as e:
226+
raise ValueError("Couldn't convert timestamps input to a tensor") from e
227+
return _get_frames_by_pts_tensor_input(decoder, timestamps=timestamps)
228+
229+
215230
# ==============================
216231
# Abstract impl for the operators. Needed by torch.compile.
217232
# ==============================
@@ -363,7 +378,7 @@ def get_frame_at_pts_abstract(
363378
def get_frames_by_pts_abstract(
364379
decoder: torch.Tensor,
365380
*,
366-
timestamps: List[float],
381+
timestamps: Union[torch.Tensor, List[float]],
367382
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
368383
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
369384
return (

src/torchcodec/decoders/_video_decoder.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -330,20 +330,17 @@ def get_frame_played_at(self, seconds: float) -> Frame:
330330
duration_seconds=duration_seconds.item(),
331331
)
332332

333-
def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
333+
def get_frames_played_at(
334+
self, seconds: Union[torch.Tensor, list[float]]
335+
) -> FrameBatch:
334336
"""Return frames played at the given timestamps in seconds.
335337
336338
Args:
337-
seconds (list of float): The timestamps in seconds when the frames are played.
339+
seconds (torch.Tensor or list of float): The timestamps in seconds when the frames are played.
338340
339341
Returns:
340342
FrameBatch: The frames that are played at ``seconds``.
341343
"""
342-
if isinstance(seconds, torch.Tensor):
343-
# TODO we should avoid converting tensors to lists and just let the
344-
# core ops and C++ code natively accept tensors. See
345-
# https://github.com/pytorch/torchcodec/issues/879
346-
seconds = seconds.to(torch.float).tolist()
347344

348345
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
349346
self._decoder, timestamps=seconds

test/test_decoders.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -646,13 +646,18 @@ def test_get_frame_played_at_fails(self, device, seek_mode):
646646

647647
@pytest.mark.parametrize("device", all_supported_devices())
648648
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
649-
def test_get_frames_played_at(self, device, seek_mode):
649+
@pytest.mark.parametrize("input_type", ("list", "tensor"))
650+
def test_get_frames_played_at(self, device, seek_mode, input_type):
650651
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
651652
device, _ = unsplit_device_str(device)
652653

653654
# Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
654655
# index 35. We use those indices as reference to test against.
655-
seconds = [0.84, 1.17, 0.85]
656+
if input_type == "list":
657+
seconds = [0.84, 1.17, 0.85]
658+
else: # tensor
659+
seconds = torch.tensor([0.84, 1.17, 0.85])
660+
656661
reference_indices = [25, 35, 25]
657662
frames = decoder.get_frames_played_at(seconds)
658663

@@ -694,7 +699,9 @@ def test_get_frames_played_at_fails(self, device, seek_mode):
694699
with pytest.raises(RuntimeError, match="must be less than"):
695700
decoder.get_frames_played_at([14])
696701

697-
with pytest.raises(RuntimeError, match="Expected a value of type"):
702+
with pytest.raises(
703+
ValueError, match="Couldn't convert timestamps input to a tensor"
704+
):
698705
decoder.get_frames_played_at(["bad"])
699706

700707
@pytest.mark.parametrize("device", all_supported_devices())

0 commit comments

Comments
 (0)