diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 1794c984b..add9c9bee 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1090,6 +1090,53 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( return output; } +VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( + int streamIndex, + const std::vector<double>& timestamps) { + validateUserProvidedStreamIndex(streamIndex); + validateScannedAllStreams("getFramesDisplayedByTimestamps"); + + // The frame displayed at timestamp t and the one displayed at timestamp `t + + // eps` are probably the same frame, with the same index. The easiest way to + // avoid decoding that unique frame twice is to convert the input timestamps + // to indices, and leverage the de-duplication logic of getFramesAtIndices. + // This means this function requires a scan. + // TODO: longer term, we should implement this without requiring a scan + + const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + const auto& stream = streams_[streamIndex]; + double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); + double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); + + std::vector<int64_t> frameIndices(timestamps.size()); + for (auto i = 0; i < timestamps.size(); ++i) { + auto framePts = timestamps[i]; + TORCH_CHECK( + framePts >= minSeconds && framePts < maxSeconds, + "frame pts is " + std::to_string(framePts) + "; must be in range [" + + std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + + ")."); + + auto it = std::lower_bound( + stream.allFrames.begin(), + stream.allFrames.end(), + framePts, + [&stream](const FrameInfo& info, double framePts) { + return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; + }); + int64_t frameIndex = it - stream.allFrames.begin(); + // If the frame index is larger than the size of allFrames, that means we + 
couldn't match the pts value to the pts value of a NEXT FRAME. And + // that means that this timestamp falls during the time between when the + // last frame is displayed, and the video ends. Hence, it should map to the + // index of the last frame. + frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1); + frameIndices[i] = frameIndex; + } + + return getFramesAtIndices(streamIndex, frameIndices); +} + VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( int streamIndex, int64_t start, diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 2adbfac64..c0f489cef 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -223,6 +223,7 @@ class VideoDecoder { // i.e. it will be returned when this function is called with seconds=5.0 or // seconds=5.999, etc. DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds); + DecodedOutput getFrameAtIndex( int streamIndex, int64_t frameIndex, @@ -242,6 +243,11 @@ class VideoDecoder { BatchDecodedOutput getFramesAtIndices( int streamIndex, const std::vector<int64_t>& frameIndices); + + BatchDecodedOutput getFramesDisplayedByTimestamps( + int streamIndex, + const std::vector<double>& timestamps); + // Returns frames within a given range for a given stream as a single stacked // Tensor. The range is defined by [start, stop). The values retrieved from // the range are: diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 0be871a3e..6b91853c9 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -45,6 +45,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_by_pts_in_range(Tensor(a!) 
decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); + m.def( + "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); m.def("get_container_json_metadata(Tensor(a!) decoder) -> str"); m.def( @@ -240,6 +242,16 @@ OpsBatchDecodedOutput get_frames_in_range( stream_index, start, stop, step.value_or(1)); return makeOpsBatchDecodedOutput(result); } +OpsBatchDecodedOutput get_frames_by_pts( + at::Tensor& decoder, + int64_t stream_index, + at::ArrayRef<double> timestamps) { + auto videoDecoder = unwrapTensorToGetDecoder(decoder); + std::vector<double> timestampsVec(timestamps.begin(), timestamps.end()); + auto result = + videoDecoder->getFramesDisplayedByTimestamps(stream_index, timestampsVec); + return makeOpsBatchDecodedOutput(result); +} OpsBatchDecodedOutput get_frames_by_pts_in_range( at::Tensor& decoder, @@ -485,6 +497,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("get_frames_at_indices", &get_frames_at_indices); m.impl("get_frames_in_range", &get_frames_in_range); m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range); + m.impl("get_frames_by_pts", &get_frames_by_pts); m.impl("_test_frame_pts_equality", &_test_frame_pts_equality); m.impl( "scan_all_streams_to_update_metadata", diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 5b442025d..eac489cea 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -75,6 +75,12 @@ using OpsBatchDecodedOutput = std::tuple; // given timestamp T has T >= PTS and T < PTS + Duration. 
OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds); +// Return the frames at given pts values for a given stream +OpsBatchDecodedOutput get_frames_by_pts( + at::Tensor& decoder, + int64_t stream_index, + at::ArrayRef<double> timestamps); + // Return the frame that is visible at a given index in the video. OpsDecodedOutput get_frame_at_index( at::Tensor& decoder, @@ -85,8 +91,7 @@ OpsDecodedOutput get_frame_at_index( // duration as tensors. OpsDecodedOutput get_next_frame(at::Tensor& decoder); -// Return the frames at a given index for a given stream as a single stacked -// Tensor. +// Return the frames at given indices for a given stream OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index bd761fe15..a1ac9a478 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -22,6 +22,7 @@ get_frame_at_index, get_frame_at_pts, get_frames_at_indices, + get_frames_by_pts, get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 01de6ad67..d4102ae5d 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -71,6 +71,7 @@ def load_torchcodec_extension(): get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default +get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default @@ -172,6 +173,21 @@ def 
get_frame_at_pts_abstract( ) +@register_fake("torchcodec_ns::get_frames_by_pts") +def get_frames_by_pts_abstract( + decoder: torch.Tensor, + *, + stream_index: int, + timestamps: List[float], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + image_size = [get_ctx().new_dynamic_size() for _ in range(4)] + return ( + torch.empty(image_size), + torch.empty([], dtype=torch.float), + torch.empty([], dtype=torch.float), + ) + + @register_fake("torchcodec_ns::get_frame_at_index") def get_frame_at_index_abstract( decoder: torch.Tensor, *, stream_index: int, frame_index: int diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index bbd9fe4e7..0ed681469 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -27,6 +27,7 @@ get_frame_at_index, get_frame_at_pts, get_frames_at_indices, + get_frames_by_pts, get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, @@ -155,6 +156,36 @@ def test_get_frames_at_indices_unsorted_indices(self): with pytest.raises(AssertionError): assert_tensor_equal(frames[0], frames[-1]) + def test_get_frames_by_pts(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + _add_video_stream(decoder) + scan_all_streams_to_update_metadata(decoder) + stream_index = 3 + + # Note: 13.01 should give the last video frame for the NASA video + timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3] + + expected_frames = [ + get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps + ] + + frames, *_ = get_frames_by_pts( + decoder, + stream_index=stream_index, + timestamps=timestamps, + ) + for frame, expected_frame in zip(frames, expected_frames): + assert_tensor_equal(frame, expected_frame) + + # first and last frame should be equal, at pts=2 [+ eps]. We then modify + # the first frame and assert that it's now different from the last + # frame. This ensures a copy was properly made during the de-duplication + # logic. 
+ assert_tensor_equal(frames[0], frames[-1]) + frames[0] += 20 + with pytest.raises(AssertionError): + assert_tensor_equal(frames[0], frames[-1]) + def test_get_frames_in_range(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder)