Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
}

FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
const std::vector<double>& timestamps) {
const torch::Tensor& timestamps) {
validateActiveStream(AVMEDIA_TYPE_VIDEO);

const auto& streamMetadata =
Expand All @@ -778,9 +778,13 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
// avoid decoding that unique frame twice is to convert the input timestamps
// to indices, and leverage the de-duplication logic of getFramesAtIndices.

std::vector<int64_t> frameIndices(timestamps.size());
for (size_t i = 0; i < timestamps.size(); ++i) {
auto frameSeconds = timestamps[i];
torch::Tensor frameIndices =
torch::empty({timestamps.numel()}, torch::kInt64);
auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();
auto timestampsAccessor = timestamps.accessor<double, 1>();

for (int64_t i = 0; i < timestamps.numel(); ++i) {
auto frameSeconds = timestampsAccessor[i];
TORCH_CHECK(
frameSeconds >= minSeconds,
"frame pts is " + std::to_string(frameSeconds) +
Expand All @@ -797,11 +801,10 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
".");
}

frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
frameIndicesAccessor[i] = secondsToIndexLowerBound(frameSeconds);
}

// TODO: Support tensors natively instead of a vector to avoid a copy.
return getFramesAtIndices(torch::tensor(frameIndices));
return getFramesAtIndices(frameIndices);
}

FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/SingleStreamDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class SingleStreamDecoder {
// seconds=5.999, etc.
FrameOutput getFramePlayedAt(double seconds);

FrameBatchOutput getFramesPlayedAt(const std::vector<double>& timestamps);
FrameBatchOutput getFramesPlayedAt(const torch::Tensor& timestamps);

// Returns frames within a given pts range. The range is defined by
// [startSeconds, stopSeconds) with respect to the pts values for frames. The
Expand Down
7 changes: 3 additions & 4 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)");
m.def(
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
"get_frames_by_pts(Tensor(a!) decoder, *, Tensor timestamps) -> (Tensor, Tensor, Tensor)");
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
Expand Down Expand Up @@ -423,10 +423,9 @@ OpsFrameBatchOutput get_frames_in_range(
// Return the frames at given ptss for a given stream
OpsFrameBatchOutput get_frames_by_pts(
at::Tensor& decoder,
at::ArrayRef<double> timestamps) {
const at::Tensor& timestamps) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
auto result = videoDecoder->getFramesPlayedAt(timestampsVec);
auto result = videoDecoder->getFramesPlayedAt(timestamps);
return makeOpsFrameBatchOutput(result);
}

Expand Down
19 changes: 17 additions & 2 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def load_torchcodec_shared_libraries():
_get_frames_at_indices_tensor_input = (
torch.ops.torchcodec_ns.get_frames_at_indices.default
)
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
_get_frames_by_pts_tensor_input = torch.ops.torchcodec_ns.get_frames_by_pts.default
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
get_frames_by_pts_in_range_audio = (
Expand Down Expand Up @@ -212,6 +212,21 @@ def get_frames_at_indices(
return _get_frames_at_indices_tensor_input(decoder, frame_indices=frame_indices)


def get_frames_by_pts(
decoder: torch.Tensor, *, timestamps: Union[torch.Tensor, list[float]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if isinstance(timestamps, torch.Tensor):
# Ensure indices is the correct dtype (float64)
timestamps = timestamps.to(torch.float64)
else:
# Convert list to tensor for dispatch
try:
timestamps = torch.tensor(timestamps, dtype=torch.float64)
except Exception as e:
raise ValueError("Couldn't convert timestamps input to a tensor") from e
return _get_frames_by_pts_tensor_input(decoder, timestamps=timestamps)


# ==============================
# Abstract impl for the operators. Needed by torch.compile.
# ==============================
Expand Down Expand Up @@ -363,7 +378,7 @@ def get_frame_at_pts_abstract(
def get_frames_by_pts_abstract(
decoder: torch.Tensor,
*,
timestamps: List[float],
timestamps: Union[torch.Tensor, List[float]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return (
Expand Down
11 changes: 4 additions & 7 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,20 +330,17 @@ def get_frame_played_at(self, seconds: float) -> Frame:
duration_seconds=duration_seconds.item(),
)

def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
def get_frames_played_at(
self, seconds: Union[torch.Tensor, list[float]]
) -> FrameBatch:
"""Return frames played at the given timestamps in seconds.

Args:
seconds (list of float): The timestamps in seconds when the frames are played.
seconds (torch.Tensor or list of float): The timestamps in seconds when the frames are played.

Returns:
FrameBatch: The frames that are played at ``seconds``.
"""
if isinstance(seconds, torch.Tensor):
# TODO we should avoid converting tensors to lists and just let the
# core ops and C++ code natively accept tensors. See
# https://github.com/pytorch/torchcodec/issues/879
seconds = seconds.to(torch.float).tolist()

data, pts_seconds, duration_seconds = core.get_frames_by_pts(
self._decoder, timestamps=seconds
Expand Down
13 changes: 10 additions & 3 deletions test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,13 +646,18 @@ def test_get_frame_played_at_fails(self, device, seek_mode):

@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frames_played_at(self, device, seek_mode):
@pytest.mark.parametrize("input_type", ("list", "tensor"))
def test_get_frames_played_at(self, device, seek_mode, input_type):
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
device, _ = unsplit_device_str(device)

# Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
# index 35. We use those indices as reference to test against.
seconds = [0.84, 1.17, 0.85]
if input_type == "list":
seconds = [0.84, 1.17, 0.85]
else: # tensor
seconds = torch.tensor([0.84, 1.17, 0.85])

reference_indices = [25, 35, 25]
frames = decoder.get_frames_played_at(seconds)

Expand Down Expand Up @@ -694,7 +699,9 @@ def test_get_frames_played_at_fails(self, device, seek_mode):
with pytest.raises(RuntimeError, match="must be less than"):
decoder.get_frames_played_at([14])

with pytest.raises(RuntimeError, match="Expected a value of type"):
with pytest.raises(
ValueError, match="Couldn't convert timestamps input to a tensor"
):
decoder.get_frames_played_at(["bad"])

@pytest.mark.parametrize("device", all_supported_devices())
Expand Down
Loading