Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,47 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
return output;
}

VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps(
Copy link
Contributor Author

@NicolasHug NicolasHug Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the function names, I based it on the existing getFramesDisplayedByTimestampsInRange(). I'm happy to bikeshed, but not on this PR please.
Note that our naming is becoming a bit inconsistent, e.g. the python names, ops names and C++ names are not always aligned. We probably want to clean that up, but that's for later.

We'll also have to re-think our public VideoDecoder method names very soon I think, because adding the 2 new "get_frames_..." that we recently added may conflict with existing names.

int streamIndex,
const std::vector<double>& timestamps) {
validateUserProvidedStreamIndex(streamIndex);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a TODO saying long term we should not require scanning the file for time-based frame extraction

validateScannedAllStreams("getFramesDisplayedByTimestamps");

// The frame displayed at timestamp t and the one displayed at timestamp `t +
// eps` are probably the same frame, with the same index. The easiest way to
// avoid decoding that unique frame twice is to convert the input timestamps
// to indices, and leverage the de-duplication logic of getFramesAtIndices.
Comment on lines +1096 to +1102
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The long term thing to do here is to not require a scan for this function.

Scan should only be required for index-based functions where indexes need to be exact.

The index-based function can then call this function if the de-duping is done here and this doesn't need a scan.

That said you can merge this as-is and do the long-term thing later

Copy link
Contributor Author

@NicolasHug NicolasHug Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ahmadsharif1 do you think it will be possible to correctly de-dup pts-based frame queries without going through indices?
That is currently the main reason we're converting to indices here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what I mean by having explicit seeking modes (approximate versus exact) is a big conceptual change. :) It might only requires changing a few dozen lines of code, but it will require changes in many places, and we'll always have to be aware of what mode is active.

At the moment, we're implicitly exact when we need to do things with indices. I think we should just keep being that way as needed, and when we implement the different modes, we can reason about what changes to make wholistically.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be possible to dedup based on time, but it needs to be done on the fly -- i.e. when you decode a frame you see the extent of that frame on the timeline and can make copies of that for the timepoints that the user specified. It could be a bit more involved than that because you don't know the extent of the frame by looking at the frame itself -- you need to read (not decode) the next frame.

Doing it holistically in a future PR sounds good

// This means this function requires a scan.
// TODO: longer term, we should implement this without requiring a scan

const auto& streamMetadata = containerMetadata_.streams[streamIndex];
const auto& stream = streams_[streamIndex];
double minSeconds = streamMetadata.minPtsSecondsFromScan.value();
double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value();

std::vector<int64_t> frameIndices(timestamps.size());
for (auto i = 0; i < timestamps.size(); ++i) {
auto framePts = timestamps[i];
TORCH_CHECK(
framePts >= minSeconds && framePts < maxSeconds,
"frame pts is " + std::to_string(framePts) + "; must be in range [" +
std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) +
").");

auto it = std::lower_bound(
stream.allFrames.begin(),
stream.allFrames.end() - 1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you intentionally not including the last frame? Note that by convention, the .end() method on C++ containers is actually an iterator past the end: https://en.cppreference.com/w/cpp/container/vector/end. So this code is actually eliminating the final element from the range.

Copy link
Contributor Author

@NicolasHug NicolasHug Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for double-checking, that was actually something that confused me a bit.

I originally had just stream.allFrames.end(), since this was copy-pasted from the getFramesDisplayedByTimestampInRange.

But with that, we would end up with an index of 390 on the NASA video for queries that were close to the video duration (~13s), when the last valid index is 389.

My first fix was to do frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1);.

But IIUC, the way std::lower_bound() works is that if it finds no frame that satisfies the condition, it will return the value of the second parameter (i.e. stream.allFrames.end() - 1), which is what we want?

(You'll see in the test I added that querying ~13s now properly returns the last frame, at index 389)

There is a very non-zero chance that I'm misunderstanding all of this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm. I think you're right. We're eliminating the final element, but if we should have gotten the final element, it will become that anyway. This works for, I think, two resons:

  1. Because we've already eliminated the possibility of a pts value outside of the valid range because of the TORCH_CHECK above.
  2. Because we're converting iterators to indices, and .end() becomes equivalent to .size() with the subtraction method we're using.

But, I think I actually prefer using min to fix things up after. That this works is so subtle, and doing .end() - 1 in an iterator range is so unusual, I'd rather have the min approach with a comment explaining why we need it. I think it makes the code more understandable. The comment would be something like:

// If the frame index is larger than the size of all frames, that means we 
// couldn't match the pts value to the pts value of a NEXT FRAME. And 
// that means that this timestamp falls during the time between when the 
// last frame is displayed, and the video ends. Hence, it should map to the
// index of the last frame.

framePts,
[&stream](const FrameInfo& info, double framePts) {
return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts;
});
int64_t frameIndex = it - stream.allFrames.begin();
frameIndices[i] = frameIndex;
}

return getFramesAtIndices(streamIndex, frameIndices);
}

VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
int streamIndex,
int64_t start,
Expand Down
6 changes: 6 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ class VideoDecoder {
// i.e. it will be returned when this function is called with seconds=5.0 or
// seconds=5.999, etc.
DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);

DecodedOutput getFrameAtIndex(
int streamIndex,
int64_t frameIndex,
Expand All @@ -242,6 +243,11 @@ class VideoDecoder {
BatchDecodedOutput getFramesAtIndices(
int streamIndex,
const std::vector<int64_t>& frameIndices);

BatchDecodedOutput getFramesDisplayedByTimestamps(
int streamIndex,
const std::vector<double>& timestamps);

// Returns frames within a given range for a given stream as a single stacked
// Tensor. The range is defined by [start, stop). The values retrieved from
// the range are:
Expand Down
13 changes: 13 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
m.def(
Expand Down Expand Up @@ -240,6 +242,16 @@ OpsBatchDecodedOutput get_frames_in_range(
stream_index, start, stop, step.value_or(1));
return makeOpsBatchDecodedOutput(result);
}
OpsBatchDecodedOutput get_frames_by_pts(
at::Tensor& decoder,
int64_t stream_index,
at::ArrayRef<double> timestamps) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
auto result =
videoDecoder->getFramesDisplayedByTimestamps(stream_index, timestampsVec);
return makeOpsBatchDecodedOutput(result);
}

OpsBatchDecodedOutput get_frames_by_pts_in_range(
at::Tensor& decoder,
Expand Down Expand Up @@ -485,6 +497,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("get_frames_at_indices", &get_frames_at_indices);
m.impl("get_frames_in_range", &get_frames_in_range);
m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range);
m.impl("get_frames_by_pts", &get_frames_by_pts);
m.impl("_test_frame_pts_equality", &_test_frame_pts_equality);
m.impl(
"scan_all_streams_to_update_metadata",
Expand Down
9 changes: 7 additions & 2 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ using OpsBatchDecodedOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
// given timestamp T has T >= PTS and T < PTS + Duration.
OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds);

// Return the frames at given ptss for a given stream
OpsBatchDecodedOutput get_frames_by_pts(
at::Tensor& decoder,
int64_t stream_index,
at::ArrayRef<double> timestamps);

// Return the frame that is visible at a given index in the video.
OpsDecodedOutput get_frame_at_index(
at::Tensor& decoder,
Expand All @@ -85,8 +91,7 @@ OpsDecodedOutput get_frame_at_index(
// duration as tensors.
OpsDecodedOutput get_next_frame(at::Tensor& decoder);

// Return the frames at a given index for a given stream as a single stacked
// Tensor.
// Return the frames at given indices for a given stream
OpsBatchDecodedOutput get_frames_at_indices(
at::Tensor& decoder,
int64_t stream_index,
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/decoders/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
get_frame_at_index,
get_frame_at_pts,
get_frames_at_indices,
get_frames_by_pts,
get_frames_by_pts_in_range,
get_frames_in_range,
get_json_metadata,
Expand Down
16 changes: 16 additions & 0 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def load_torchcodec_extension():
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
Expand Down Expand Up @@ -172,6 +173,21 @@ def get_frame_at_pts_abstract(
)


@register_fake("torchcodec_ns::get_frames_by_pts")
def get_frames_by_pts_abstract(
decoder: torch.Tensor,
*,
stream_index: int,
timestamps: List[float],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return (
torch.empty(image_size),
torch.empty([], dtype=torch.float),
torch.empty([], dtype=torch.float),
)


@register_fake("torchcodec_ns::get_frame_at_index")
def get_frame_at_index_abstract(
decoder: torch.Tensor, *, stream_index: int, frame_index: int
Expand Down
31 changes: 31 additions & 0 deletions test/decoders/test_video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
get_frame_at_index,
get_frame_at_pts,
get_frames_at_indices,
get_frames_by_pts,
get_frames_by_pts_in_range,
get_frames_in_range,
get_json_metadata,
Expand Down Expand Up @@ -155,6 +156,36 @@ def test_get_frames_at_indices_unsorted_indices(self):
with pytest.raises(AssertionError):
assert_tensor_equal(frames[0], frames[-1])

def test_get_frames_by_pts(self):
decoder = create_from_file(str(NASA_VIDEO.path))
_add_video_stream(decoder)
scan_all_streams_to_update_metadata(decoder)
stream_index = 3

# Note: 13.01 should give the last video frame for the NASA video
timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3]

expected_frames = [
get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps
]

frames, *_ = get_frames_by_pts(
decoder,
stream_index=stream_index,
timestamps=timestamps,
)
for frame, expected_frame in zip(frames, expected_frames):
assert_tensor_equal(frame, expected_frame)

# # first and last frame should be equal, at pts=2 [+ eps]. We then
# modify the # first frame and assert that it's now different from the
# last frame. # This ensures a copy was properly made during the
# de-duplication logic.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: looks like auto-formatting had some problems when merging line comments.

assert_tensor_equal(frames[0], frames[-1])
frames[0] += 20
with pytest.raises(AssertionError):
assert_tensor_equal(frames[0], frames[-1])

def test_get_frames_in_range(self):
decoder = create_from_file(str(NASA_VIDEO.path))
scan_all_streams_to_update_metadata(decoder)
Expand Down
Loading