57 changes: 56 additions & 1 deletion src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -844,7 +844,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(

FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
double startSeconds,
double stopSeconds) {
double stopSeconds,
std::optional<double> fps) {
Contributor

I have a slight preference for SingleStreamDecoder to not take optionals, and for the bridge code in custom_ops.cpp to take care of the optional logic. That said, other APIs in SingleStreamDecoder do take optionals, so I'll invite @Dan-Flores and @NicolasHug to make the call here.

Contributor

It seems natural to me to have an optional here. Letting custom_ops.cpp be the bridge would, I think, mean creating more private methods to split the implementation (e.g. for the common checks at the beginning and the common logic at the end).

validateActiveStream(AVMEDIA_TYPE_VIDEO);
const auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
@@ -906,6 +907,60 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
std::to_string(maxSeconds.value()) + ").");
}

// Resample frames to match the target frame rate
Contributor Author

@mollyxu mollyxu Dec 30, 2025

We could add an early break if the requested fps is the same as the current fps. Not sure whether that is a necessary use case.

Contributor

Let's handle that with a follow-up. Because we'll need to do equality on a double, there may be some subtle things to handle.
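
To illustrate the subtlety with equality on a double: a follow-up would probably want a tolerance-based comparison rather than `==`. A minimal Python sketch of the idea, assuming a hypothetical helper `fps_matches` and an arbitrary tolerance of `1e-6`; neither is part of this PR.

```python
import math


def fps_matches(requested_fps: float, source_fps: float, rel_tol: float = 1e-6) -> bool:
    """Hypothetical helper: True when the two rates are equal up to rel_tol."""
    # Exact == on floats is brittle: values that are mathematically equal can
    # differ in their last bits depending on how they were computed.
    return math.isclose(requested_fps, source_fps, rel_tol=rel_tol)


# A value produced by arithmetic may not compare equal with ==.
computed = 0.1 + 0.2
print(computed == 0.3)             # False
print(fps_matches(computed, 0.3))  # True
print(fps_matches(30.0, 29.97))    # False
```

The right tolerance is a design decision: too tight and the early exit never triggers, too loose and distinct rates get treated as identical.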

if (fps.has_value()) {
TORCH_CHECK(
fps.value() > 0,
"fps must be positive, got " + std::to_string(fps.value()));

// TODO: add an early break if requested fps is the same as the current fps

double fpsVal = fps.value();
double frameDuration = 1.0 / fpsVal;

double product = (stopSeconds - startSeconds) * fpsVal;
int64_t numOutputFrames = static_cast<int64_t>(std::round(product));

// Generate target timestamps and find source frame indices
std::vector<int64_t> sourceFrameIndices(numOutputFrames);
std::vector<double> targetTimestamps(numOutputFrames);
for (int64_t i = 0; i < numOutputFrames; ++i) {
targetTimestamps[i] = startSeconds + i * frameDuration;
sourceFrameIndices[i] = secondsToIndexLowerBound(targetTimestamps[i]);
}

FrameBatchOutput frameBatchOutput(
numOutputFrames,
resizedOutputDims_.value_or(metadataDims_),
videoStreamOptions.device);

// Decode frames, reusing already-decoded frames for duplicates
int64_t lastDecodedSourceIndex = -1;
torch::Tensor lastDecodedData;

for (int64_t i = 0; i < numOutputFrames; ++i) {
int64_t sourceIdx = sourceFrameIndices[i];

if (sourceIdx == lastDecodedSourceIndex && lastDecodedSourceIndex >= 0) {
frameBatchOutput.data[i].copy_(lastDecodedData);
} else {
FrameOutput frameOutput =
getFrameAtIndexInternal(sourceIdx, frameBatchOutput.data[i]);
lastDecodedData = frameBatchOutput.data[i];
lastDecodedSourceIndex = sourceIdx;
}

frameBatchOutput.ptsSeconds[i] = targetTimestamps[i];
frameBatchOutput.durationSeconds[i] = frameDuration;
}

frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
return frameBatchOutput;
}

// Original behavior when fps is not specified:
// Return all frames in range at source fps

// Note that we look at nextPts for a frame, and not its pts or duration.
// Our abstract player displays frames starting at the pts for that frame
// until the pts for the next frame. There are two consequences:
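
For readers following the resampling branch in `getFramesPlayedInRange`, the timestamp-to-index step can be sketched in Python as below. This is a sketch under stated assumptions, not the decoder's implementation: `resample_indices` and its `frame_pts` argument are hypothetical, and the lookup assumes a frame is shown from its pts until the next frame's pts, which may differ from what `secondsToIndexLowerBound` does internally.

```python
import math
from bisect import bisect_right


def resample_indices(frame_pts, start_seconds, stop_seconds, fps):
    """Map evenly spaced target timestamps to source frame indices.

    frame_pts is a sorted list of source frame start times in seconds
    (hypothetical input; the real decoder uses its own metadata).
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")
    frame_duration = 1.0 / fps
    product = (stop_seconds - start_seconds) * fps
    # floor(x + 0.5) rounds halves up, like std::round for non-negative x.
    # Python's built-in round() rounds halves to even, so it is avoided here.
    num_output_frames = math.floor(product + 0.5)
    timestamps = [start_seconds + i * frame_duration for i in range(num_output_frames)]
    # Assumption: the frame shown at time t is the last one whose pts is <= t.
    indices = [max(bisect_right(frame_pts, t) - 1, 0) for t in timestamps]
    return timestamps, indices


source_pts = [0.0, 0.25, 0.5, 0.75]  # a 4 fps source lasting one second

# Upsampling to 8 fps duplicates every source frame.
_, up = resample_indices(source_pts, 0.0, 1.0, 8.0)
print(up)    # [0, 0, 1, 1, 2, 2, 3, 3]

# Downsampling to 2 fps drops every other source frame.
_, down = resample_indices(source_pts, 0.0, 1.0, 2.0)
print(down)  # [0, 2]
```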
11 changes: 5 additions & 6 deletions src/torchcodec/_core/SingleStreamDecoder.h
@@ -147,9 +147,13 @@ class SingleStreamDecoder {
// Valid values for startSeconds and stopSeconds are:
//
// [beginStreamPtsSecondsFromContent, endStreamPtsSecondsFromContent)
//
// If fps is specified, frames are resampled to match the target frame
// rate by duplicating or dropping frames as necessary.
FrameBatchOutput getFramesPlayedInRange(
double startSeconds,
double stopSeconds);
double stopSeconds,
std::optional<double> fps = std::nullopt);

AudioFramesOutput getFramesPlayedInRangeAudio(
double startSeconds,
@@ -273,11 +277,6 @@ class SingleStreamDecoder {
UniqueAVFrame& avFrame,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

void convertAVFrameToFrameOutputOnCPU(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

// --------------------------------------------------------------------------
// PTS <-> INDEX CONVERSIONS
// --------------------------------------------------------------------------
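
The declaration comment for `getFramesPlayedInRange` describes resampling as duplicating or dropping frames. When a frame is duplicated, it only needs to be decoded once. A minimal Python sketch of that reuse pattern, assuming a hypothetical `decode` callable and a hypothetical `gather_frames` helper (neither exists in torchcodec):

```python
def gather_frames(source_indices, decode):
    """Return one frame per entry of source_indices, decoding each run once.

    decode is a hypothetical callable mapping a source index to a frame.
    """
    frames = []
    last_index = None
    last_frame = None
    for index in source_indices:
        if index != last_index:
            # New source frame: decode it and remember it.
            last_frame = decode(index)
            last_index = index
        # Otherwise the index repeats, so the previous frame is reused.
        frames.append(last_frame)
    return frames


calls = []


def fake_decode(index):
    # Stand-in decoder that records which indices were actually decoded.
    calls.append(index)
    return f"frame-{index}"


frames = gather_frames([0, 0, 1, 1, 2], fake_decode)
print(frames)  # ['frame-0', 'frame-0', 'frame-1', 'frame-1', 'frame-2']
print(calls)   # [0, 1, 2]
```

Tracking only the previous index is enough here because the target timestamps increase, so repeated source indices are always adjacent. Unlike this sketch, which shares one object between duplicates, the C++ code copies the decoded tensor into each output slot.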
9 changes: 6 additions & 3 deletions src/torchcodec/_core/custom_ops.cpp
@@ -63,7 +63,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds, float? fps=None) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)");
m.def(
@@ -575,13 +575,16 @@ OpsFrameBatchOutput get_frames_by_pts(
// Return the frames inside the range as a single stacked Tensor. The range is
// defined as [start_seconds, stop_seconds). The frames are stacked in pts
// order.
// If fps is specified, frames are resampled to match the target frame
// rate by duplicating or dropping frames as necessary.
OpsFrameBatchOutput get_frames_by_pts_in_range(
at::Tensor& decoder,
double start_seconds,
double stop_seconds) {
double stop_seconds,
std::optional<double> fps = std::nullopt) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result =
videoDecoder->getFramesPlayedInRange(start_seconds, stop_seconds);
videoDecoder->getFramesPlayedInRange(start_seconds, stop_seconds, fps);
return makeOpsFrameBatchOutput(result);
}

1 change: 1 addition & 0 deletions src/torchcodec/_core/ops.py
@@ -535,6 +535,7 @@ def get_frames_by_pts_in_range_abstract(
*,
start_seconds: float,
stop_seconds: float,
fps: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return (
40 changes: 28 additions & 12 deletions src/torchcodec/decoders/_video_decoder.py
@@ -15,7 +15,6 @@

import torch
from torch import device as torch_device, nn, Tensor

from torchcodec import _core as core, Frame, FrameBatch
from torchcodec.decoders._decoder_utils import (
_get_cuda_backend,
@@ -452,32 +451,31 @@ def get_frames_played_at(self, seconds: torch.Tensor | list[float]) -> FrameBatc
)

def get_frames_played_in_range(
self, start_seconds: float, stop_seconds: float
self, start_seconds: float, stop_seconds: float, fps: float | None = None
Contributor

@NicolasHug NicolasHug Jan 6, 2026

If we expose the fps functionality in this method, then for convenience I think we should make sure to also expose a get_all_frames(fps) method, so that users don't have to manually specify both start and stop. (Exposing this get_all_frames() method is something we should do regardless).

) -> FrameBatch:
"""Returns multiple frames in the given range.

Frames are in the half open range [start_seconds, stop_seconds). Each
returned frame's :term:`pts`, in seconds, is inside of the half open
range.

Args:
start_seconds (float): Time, in seconds, of the start of the
range.
stop_seconds (float): Time, in seconds, of the end of the
range. As a half open range, the end is excluded.
start_seconds (float): Time, in seconds, of the start of the range.
stop_seconds (float): Time, in seconds, of the end of the range.
As a half open range, the end is excluded.
fps (float, optional): If specified, resample output to this frame
rate by duplicating or dropping frames as necessary. If None
(default), returns frames at the source video's frame rate.

Returns:
FrameBatch: The frames within the specified range.
"""
if not start_seconds <= stop_seconds:
raise ValueError(
f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})."
f"Invalid start seconds: {start_seconds}. "
f"It must be less than or equal to stop seconds ({stop_seconds})."
)
if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds:
raise ValueError(
f"Invalid start seconds: {start_seconds}. "
f"It must be greater than or equal to {self._begin_stream_seconds} "
f"and less than or equal to {self._end_stream_seconds}."
f"and less than {self._end_stream_seconds}."
)
if not stop_seconds <= self._end_stream_seconds:
raise ValueError(
@@ -488,9 +486,27 @@ def get_frames_played_in_range(
self._decoder,
start_seconds=start_seconds,
stop_seconds=stop_seconds,
fps=fps,
)
return FrameBatch(*frames)

def get_all_frames(self, fps: float | None = None) -> FrameBatch:
"""Returns all frames in the video.

Args:
fps (float, optional): If specified, resample output to this frame
rate by duplicating or dropping frames as necessary. If None
(default), returns frames at the source video's frame rate.

Returns:
FrameBatch: All frames in the video.
"""
return self.get_frames_played_in_range(
start_seconds=self._begin_stream_seconds,
stop_seconds=self._end_stream_seconds,
fps=fps,
)
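
Since `get_all_frames` forwards the stream's begin and end times, it is worth checking that the full range passes the validation in `get_frames_played_in_range`. A standalone Python sketch of those checks, assuming a hypothetical free function `validate_range`; the first two messages follow the diff, while the third is a placeholder because the original message is not visible in this hunk.

```python
def validate_range(start_seconds, stop_seconds, begin_stream_seconds, end_stream_seconds):
    """Hypothetical standalone version of the range checks shown in the diff."""
    if not start_seconds <= stop_seconds:
        raise ValueError(
            f"Invalid start seconds: {start_seconds}. "
            f"It must be less than or equal to stop seconds ({stop_seconds})."
        )
    # The start is checked against a half-open interval: begin <= start < end.
    if not begin_stream_seconds <= start_seconds < end_stream_seconds:
        raise ValueError(
            f"Invalid start seconds: {start_seconds}. "
            f"It must be greater than or equal to {begin_stream_seconds} "
            f"and less than {end_stream_seconds}."
        )
    if not stop_seconds <= end_stream_seconds:
        # Placeholder message; the real one is elided in the diff.
        raise ValueError(f"Invalid stop seconds: {stop_seconds}.")


# The full range, as get_all_frames passes it, is accepted.
validate_range(0.0, 10.0, begin_stream_seconds=0.0, end_stream_seconds=10.0)

# A start equal to the stream end is rejected, matching the corrected message.
try:
    validate_range(10.0, 10.0, begin_stream_seconds=0.0, end_stream_seconds=10.0)
except ValueError as error:
    print(error)
```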


def _get_and_validate_stream_metadata(
*,