// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
66
7- #include " src/torchcodec/decoders/_core/VideoDecoderOps.h"
87#include < pybind11/pybind11.h>
98#include < cstdint>
109#include < sstream>
@@ -85,33 +84,82 @@ VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
8584 return decoder;
8685}
8786
87+ // The elements of this tuple are all tensors that represent a single frame:
88+ // 1. The frame data, which is a multidimensional tensor.
89+ // 2. A single float value for the pts in seconds.
90+ // 3. A single float value for the duration in seconds.
91+ // The reason we use Tensors for the second and third values is so we can run
92+ // under torch.compile().
93+ using OpsFrameOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
94+
8895OpsFrameOutput makeOpsFrameOutput (VideoDecoder::FrameOutput& frame) {
8996 return std::make_tuple (
9097 frame.data ,
9198 torch::tensor (frame.ptsSeconds , torch::dtype (torch::kFloat64 )),
9299 torch::tensor (frame.durationSeconds , torch::dtype (torch::kFloat64 )));
93100}
94101
102+ // All elements of this tuple are tensors of the same leading dimension. The
103+ // tuple represents the frames for N total frames, where N is the dimension of
104+ // each stacked tensor. The elments are:
105+ // 1. Stacked tensor of data for all N frames. Each frame is also a
106+ // multidimensional tensor.
107+ // 2. Tensor of N pts values in seconds, where each pts is a single
108+ // float.
109+ // 3. Tensor of N durationis in seconds, where each duration is a
110+ // single float.
111+ using OpsFrameBatchOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
112+
95113OpsFrameBatchOutput makeOpsFrameBatchOutput (
96114 VideoDecoder::FrameBatchOutput& batch) {
97115 return std::make_tuple (batch.data , batch.ptsSeconds , batch.durationSeconds );
98116}
99117
118+ // The elements of this tuple are all tensors that represent the concatenation
119+ // of multiple audio frames:
120+ // 1. The frames data (concatenated)
121+ // 2. A single float value for the pts of the first frame, in seconds.
122+ using OpsAudioFramesOutput = std::tuple<at::Tensor, at::Tensor>;
123+
100124OpsAudioFramesOutput makeOpsAudioFramesOutput (
101125 VideoDecoder::AudioFramesOutput& audioFrames) {
102126 return std::make_tuple (
103127 audioFrames.data ,
104128 torch::tensor (audioFrames.ptsSeconds , torch::dtype (torch::kFloat64 )));
105129}
130+
// Wrap `value` in double quotes so it can be embedded as a JSON string.
// Note: the value is emitted as-is, with no escaping of embedded quotes.
std::string quoteValue(const std::string& value) {
  std::string quoted;
  quoted.reserve(value.size() + 2);
  quoted += '"';
  quoted += value;
  quoted += '"';
  return quoted;
}
134+
// Serialize a string->string map as a JSON object, one "key": value pair per
// line. Values are emitted verbatim (no quoting or escaping), so callers must
// pass values that are already valid JSON, e.g. produced via quoteValue().
std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
  std::stringstream ss;
  ss << "{\n";
  bool firstEntry = true;
  for (const auto& [key, value] : metadataMap) {
    if (!firstEntry) {
      ss << ",\n";
    }
    ss << "\"" << key << "\": " << value;
    firstEntry = false;
  }
  if (!firstEntry) {
    // Terminate the last entry's line; an empty map stays as "{\n}".
    ss << "\n";
  }
  ss << "}";

  return ss.str();
}
152+
106153} // namespace
107154
108155// ==============================
109156// Implementations for the operators
110157// ==============================
111158
159+ // Create a VideoDecoder from file and wrap the pointer in a tensor.
112160at::Tensor create_from_file (
113161 std::string_view filename,
114- std::optional<std::string_view> seek_mode) {
162+ std::optional<std::string_view> seek_mode = std:: nullopt ) {
115163 std::string filenameStr (filename);
116164
117165 VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact;
@@ -125,9 +173,11 @@ at::Tensor create_from_file(
125173 return wrapDecoderPointerToTensor (std::move (uniqueDecoder));
126174}
127175
176+ // Create a VideoDecoder from the actual bytes of a video and wrap the pointer
177+ // in a tensor. The VideoDecoder will decode the provided bytes.
128178at::Tensor create_from_tensor (
129179 at::Tensor video_tensor,
130- std::optional<std::string_view> seek_mode) {
180+ std::optional<std::string_view> seek_mode = std:: nullopt ) {
131181 TORCH_CHECK (video_tensor.is_contiguous (), " video_tensor must be contiguous" );
132182 TORCH_CHECK (
133183 video_tensor.scalar_type () == torch::kUInt8 ,
@@ -153,33 +203,15 @@ at::Tensor _convert_to_tensor(int64_t decoder_ptr) {
153203 return wrapDecoderPointerToTensor (std::move (uniqueDecoder));
154204}
155205
156- void add_video_stream (
157- at::Tensor& decoder,
158- std::optional<int64_t > width,
159- std::optional<int64_t > height,
160- std::optional<int64_t > num_threads,
161- std::optional<std::string_view> dimension_order,
162- std::optional<int64_t > stream_index,
163- std::optional<std::string_view> device) {
164- _add_video_stream (
165- decoder,
166- width,
167- height,
168- num_threads,
169- dimension_order,
170- stream_index,
171- device);
172- }
173-
174206void _add_video_stream (
175207 at::Tensor& decoder,
176- std::optional<int64_t > width,
177- std::optional<int64_t > height,
178- std::optional<int64_t > num_threads,
179- std::optional<std::string_view> dimension_order,
180- std::optional<int64_t > stream_index,
181- std::optional<std::string_view> device,
182- std::optional<std::string_view> color_conversion_library) {
208+ std::optional<int64_t > width = std:: nullopt ,
209+ std::optional<int64_t > height = std:: nullopt ,
210+ std::optional<int64_t > num_threads = std:: nullopt ,
211+ std::optional<std::string_view> dimension_order = std:: nullopt ,
212+ std::optional<int64_t > stream_index = std:: nullopt ,
213+ std::optional<std::string_view> device = std:: nullopt ,
214+ std::optional<std::string_view> color_conversion_library = std:: nullopt ) {
183215 VideoDecoder::VideoStreamOptions videoStreamOptions;
184216 videoStreamOptions.width = width;
185217 videoStreamOptions.height = height;
@@ -221,22 +253,44 @@ void _add_video_stream(
221253 videoDecoder->addVideoStream (stream_index.value_or (-1 ), videoStreamOptions);
222254}
223255
256+ // Add a new video stream at `stream_index` using the provided options.
257+ void add_video_stream (
258+ at::Tensor& decoder,
259+ std::optional<int64_t > width = std::nullopt ,
260+ std::optional<int64_t > height = std::nullopt ,
261+ std::optional<int64_t > num_threads = std::nullopt ,
262+ std::optional<std::string_view> dimension_order = std::nullopt ,
263+ std::optional<int64_t > stream_index = std::nullopt ,
264+ std::optional<std::string_view> device = std::nullopt ) {
265+ _add_video_stream (
266+ decoder,
267+ width,
268+ height,
269+ num_threads,
270+ dimension_order,
271+ stream_index,
272+ device);
273+ }
274+
224275void add_audio_stream (
225276 at::Tensor& decoder,
226- std::optional<int64_t > stream_index,
227- std::optional<int64_t > sample_rate) {
277+ std::optional<int64_t > stream_index = std:: nullopt ,
278+ std::optional<int64_t > sample_rate = std:: nullopt ) {
228279 VideoDecoder::AudioStreamOptions audioStreamOptions;
229280 audioStreamOptions.sampleRate = sample_rate;
230281
231282 auto videoDecoder = unwrapTensorToGetDecoder (decoder);
232283 videoDecoder->addAudioStream (stream_index.value_or (-1 ), audioStreamOptions);
233284}
234285
286+ // Seek to a particular presentation timestamp in the video in seconds.
235287void seek_to_pts (at::Tensor& decoder, double seconds) {
236288 auto videoDecoder = static_cast <VideoDecoder*>(decoder.mutable_data_ptr ());
237289 videoDecoder->setCursorPtsInSeconds (seconds);
238290}
239291
292+ // Get the next frame from the video as a tuple that has the frame data, pts and
293+ // duration as tensors.
240294OpsFrameOutput get_next_frame (at::Tensor& decoder) {
241295 auto videoDecoder = unwrapTensorToGetDecoder (decoder);
242296 VideoDecoder::FrameOutput result;
@@ -248,6 +302,9 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) {
248302 return makeOpsFrameOutput (result);
249303}
250304
305+ // Return the frame that is visible at a given timestamp in seconds. Each frame
306+ // in FFMPEG has a presentation timestamp and a duration. The frame visible at a
307+ // given timestamp T has T >= PTS and T < PTS + Duration.
251308OpsFrameOutput get_frame_at_pts (at::Tensor& decoder, double seconds) {
252309 auto videoDecoder = unwrapTensorToGetDecoder (decoder);
253310 VideoDecoder::FrameOutput result;
@@ -259,12 +316,14 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
259316 return makeOpsFrameOutput (result);
260317}
261318
319+ // Return the frame that is visible at a given index in the video.
262320OpsFrameOutput get_frame_at_index (at::Tensor& decoder, int64_t frame_index) {
263321 auto videoDecoder = unwrapTensorToGetDecoder (decoder);
264322 auto result = videoDecoder->getFrameAtIndex (frame_index);
265323 return makeOpsFrameOutput (result);
266324}
267325
326+ // Return the frames at given indices for a given stream
268327OpsFrameBatchOutput get_frames_at_indices (
269328 at::Tensor& decoder,
270329 at::IntArrayRef frame_indices) {
@@ -275,16 +334,19 @@ OpsFrameBatchOutput get_frames_at_indices(
275334 return makeOpsFrameBatchOutput (result);
276335}
277336
337+ // Return the frames inside a range as a single stacked Tensor. The range is
338+ // defined as [start, stop).
278339OpsFrameBatchOutput get_frames_in_range (
279340 at::Tensor& decoder,
280341 int64_t start,
281342 int64_t stop,
282- std::optional<int64_t > step) {
343+ std::optional<int64_t > step = std:: nullopt ) {
283344 auto videoDecoder = unwrapTensorToGetDecoder (decoder);
284345 auto result = videoDecoder->getFramesInRange (start, stop, step.value_or (1 ));
285346 return makeOpsFrameBatchOutput (result);
286347}
287348
349+ // Return the frames at given ptss for a given stream
288350OpsFrameBatchOutput get_frames_by_pts (
289351 at::Tensor& decoder,
290352 at::ArrayRef<double > timestamps) {
@@ -294,6 +356,9 @@ OpsFrameBatchOutput get_frames_by_pts(
294356 return makeOpsFrameBatchOutput (result);
295357}
296358
359+ // Return the frames inside the range as a single stacked Tensor. The range is
360+ // defined as [start_seconds, stop_seconds). The frames are stacked in pts
361+ // order.
297362OpsFrameBatchOutput get_frames_by_pts_in_range (
298363 at::Tensor& decoder,
299364 double start_seconds,
@@ -307,35 +372,22 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
307372OpsAudioFramesOutput get_frames_by_pts_in_range_audio (
308373 at::Tensor& decoder,
309374 double start_seconds,
310- std::optional<double > stop_seconds) {
375+ std::optional<double > stop_seconds = std:: nullopt ) {
311376 auto videoDecoder = unwrapTensorToGetDecoder (decoder);
312377 auto result =
313378 videoDecoder->getFramesPlayedInRangeAudio (start_seconds, stop_seconds);
314379 return makeOpsAudioFramesOutput (result);
315380}
316381
317- std::string quoteValue (const std::string& value) {
318- return " \" " + value + " \" " ;
319- }
320-
321- std::string mapToJson (const std::map<std::string, std::string>& metadataMap) {
322- std::stringstream ss;
323- ss << " {\n " ;
324- auto it = metadataMap.begin ();
325- while (it != metadataMap.end ()) {
326- ss << " \" " << it->first << " \" : " << it->second ;
327- ++it;
328- if (it != metadataMap.end ()) {
329- ss << " ,\n " ;
330- } else {
331- ss << " \n " ;
332- }
333- }
334- ss << " }" ;
335-
336- return ss.str ();
337- }
338-
382+ // For testing only. We need to implement this operation as a core library
383+ // function because what we're testing is round-tripping pts values as
384+ // double-precision floating point numbers from C++ to Python and back to C++.
385+ // We want to make sure that the value is preserved exactly, bit-for-bit, during
386+ // this process.
387+ //
388+ // Returns true if for the given decoder, the pts
389+ // value when converted to seconds as a double is exactly pts_seconds_to_test.
390+ // Returns false otherwise.
339391bool _test_frame_pts_equality (
340392 at::Tensor& decoder,
341393 int64_t frame_index,
@@ -350,6 +402,7 @@ torch::Tensor _get_key_frame_indices(at::Tensor& decoder) {
350402 return videoDecoder->getKeyFrameIndices ();
351403}
352404
405+ // Get the metadata from the video as a string.
353406std::string get_json_metadata (at::Tensor& decoder) {
354407 auto videoDecoder = unwrapTensorToGetDecoder (decoder);
355408
@@ -418,6 +471,7 @@ std::string get_json_metadata(at::Tensor& decoder) {
418471 return mapToJson (metadataMap);
419472}
420473
474+ // Get the container metadata as a string.
421475std::string get_container_json_metadata (at::Tensor& decoder) {
422476 auto videoDecoder = unwrapTensorToGetDecoder (decoder);
423477
@@ -448,6 +502,7 @@ std::string get_container_json_metadata(at::Tensor& decoder) {
448502 return mapToJson (map);
449503}
450504
505+ // Get the stream metadata as a string.
451506std::string get_stream_json_metadata (
452507 at::Tensor& decoder,
453508 int64_t stream_index) {
@@ -519,6 +574,8 @@ std::string get_stream_json_metadata(
519574 return mapToJson (map);
520575}
521576
577+ // Returns version information about the various FFMPEG libraries that are
578+ // loaded in the program's address space.
522579std::string _get_json_ffmpeg_library_versions () {
523580 std::stringstream ss;
524581 ss << " {\n " ;
@@ -545,6 +602,10 @@ std::string _get_json_ffmpeg_library_versions() {
545602 return ss.str ();
546603}
547604
605+ // Scans video packets to get more accurate metadata like frame count, exact
606+ // keyframe positions, etc. Exact keyframe positions are useful for efficient
607+ // accurate seeking. Note that this function reads the entire video but it does
608+ // not decode frames. Reading a video file is much cheaper than decoding it.
548609void scan_all_streams_to_update_metadata (at::Tensor& decoder) {
549610 auto videoDecoder = unwrapTensorToGetDecoder (decoder);
550611 videoDecoder->scanFileAndUpdateMetadataAndIndex ();
0 commit comments