Skip to content

Commit 6de8692

Browse files
author
Molly Xu
committed
Added proper tensor support for get_frames_played_at()
Summary: Modified get_frames_played_at in _video_decoder to accept tensors and updated all downstream functions to natively accept tensors rather than converting them to lists. Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 0511d10 commit 6de8692

File tree

5 files changed

+36
-21
lines changed

5 files changed

+36
-21
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
753753
}
754754

755755
FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
756-
const std::vector<double>& timestamps) {
756+
const torch::Tensor& timestamps) {
757757
validateActiveStream(AVMEDIA_TYPE_VIDEO);
758758

759759
const auto& streamMetadata =
@@ -767,9 +767,13 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
767767
// avoid decoding that unique frame twice is to convert the input timestamps
768768
// to indices, and leverage the de-duplication logic of getFramesAtIndices.
769769

770-
std::vector<int64_t> frameIndices(timestamps.size());
771-
for (size_t i = 0; i < timestamps.size(); ++i) {
772-
auto frameSeconds = timestamps[i];
770+
torch::Tensor frameIndices =
771+
torch::empty({timestamps.numel()}, torch::kInt64);
772+
auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();
773+
auto timestampsAccessor = timestamps.accessor<double, 1>();
774+
775+
for (int64_t i = 0; i < timestamps.numel(); ++i) {
776+
auto frameSeconds = timestampsAccessor[i];
773777
TORCH_CHECK(
774778
frameSeconds >= minSeconds,
775779
"frame pts is " + std::to_string(frameSeconds) +
@@ -786,11 +790,10 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
786790
".");
787791
}
788792

789-
frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
793+
frameIndicesAccessor[i] = secondsToIndexLowerBound(frameSeconds);
790794
}
791795

792-
// TODO: Support tensors natively instead of a vector to avoid a copy.
793-
return getFramesAtIndices(torch::tensor(frameIndices));
796+
return getFramesAtIndices(frameIndices);
794797
}
795798

796799
FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class SingleStreamDecoder {
121121
// seconds=5.999, etc.
122122
FrameOutput getFramePlayedAt(double seconds);
123123

124-
FrameBatchOutput getFramesPlayedAt(const std::vector<double>& timestamps);
124+
FrameBatchOutput getFramesPlayedAt(const torch::Tensor& timestamps);
125125

126126
// Returns frames within a given pts range. The range is defined by
127127
// [startSeconds, stopSeconds) with respect to the pts values for frames. The

src/torchcodec/_core/custom_ops.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
6363
m.def(
6464
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)");
6565
m.def(
66-
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
66+
"get_frames_by_pts(Tensor(a!) decoder, *, Tensor timestamps) -> (Tensor, Tensor, Tensor)");
6767
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
6868
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
6969
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
@@ -405,10 +405,9 @@ OpsFrameBatchOutput get_frames_in_range(
405405
// Return the frames at given ptss for a given stream
406406
OpsFrameBatchOutput get_frames_by_pts(
407407
at::Tensor& decoder,
408-
at::ArrayRef<double> timestamps) {
408+
const at::Tensor& timestamps) {
409409
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
410-
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
411-
auto result = videoDecoder->getFramesPlayedAt(timestampsVec);
410+
auto result = videoDecoder->getFramesPlayedAt(timestamps);
412411
return makeOpsFrameBatchOutput(result);
413412
}
414413

src/torchcodec/_core/ops.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def load_torchcodec_shared_libraries():
117117
_get_frames_at_indices_tensor_input = (
118118
torch.ops.torchcodec_ns.get_frames_at_indices.default
119119
)
120-
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
120+
_get_frames_by_pts_tensor_input = torch.ops.torchcodec_ns.get_frames_by_pts.default
121121
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
122122
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
123123
get_frames_by_pts_in_range_audio = (
@@ -212,6 +212,22 @@ def get_frames_at_indices(
212212
return _get_frames_at_indices_tensor_input(decoder, frame_indices=frame_indices)
213213

214214

215+
def get_frames_by_pts(
216+
decoder: torch.Tensor, *, timestamps: Union[torch.Tensor, list[float]]
217+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
218+
if isinstance(timestamps, torch.Tensor):
219+
# Ensure indices is the correct dtype (float64)
220+
timestamps = timestamps.to(torch.float64)
221+
else:
222+
# Convert list to tensor for dispatch
223+
try:
224+
timestamps = torch.tensor(timestamps, dtype=torch.float64)
225+
except (ValueError, TypeError):
226+
# Type validation in C++ layer
227+
pass
228+
return _get_frames_by_pts_tensor_input(decoder, timestamps=timestamps)
229+
230+
215231
# ==============================
216232
# Abstract impl for the operators. Needed by torch.compile.
217233
# ==============================
@@ -363,7 +379,7 @@ def get_frame_at_pts_abstract(
363379
def get_frames_by_pts_abstract(
364380
decoder: torch.Tensor,
365381
*,
366-
timestamps: List[float],
382+
timestamps: Union[torch.Tensor, List[float]],
367383
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
368384
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
369385
return (

src/torchcodec/decoders/_video_decoder.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,20 +336,17 @@ def get_frame_played_at(self, seconds: float) -> Frame:
336336
duration_seconds=duration_seconds.item(),
337337
)
338338

339-
def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
339+
def get_frames_played_at(
340+
self, seconds: Union[torch.Tensor, list[float]]
341+
) -> FrameBatch:
340342
"""Return frames played at the given timestamps in seconds.
341343
342344
Args:
343-
seconds (list of float): The timestamps in seconds when the frames are played.
345+
seconds (torch.Tensor or list of float): The timestamps in seconds when the frames are played.
344346
345347
Returns:
346348
FrameBatch: The frames that are played at ``seconds``.
347349
"""
348-
if isinstance(seconds, torch.Tensor):
349-
# TODO we should avoid converting tensors to lists and just let the
350-
# core ops and C++ code natively accept tensors. See
351-
# https://github.com/pytorch/torchcodec/issues/879
352-
seconds = seconds.to(torch.float).tolist()
353350

354351
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
355352
self._decoder, timestamps=seconds

0 commit comments

Comments
 (0)