Skip to content

Commit 169b4d6

Browse files
committed
Remove header for custom ops
1 parent a864bf9 commit 169b4d6

File tree

2 files changed

+115
-216
lines changed

2 files changed

+115
-216
lines changed

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 115 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
// This source code is licensed under the BSD-style license found in the
55
// LICENSE file in the root directory of this source tree.
66

7-
#include "src/torchcodec/decoders/_core/VideoDecoderOps.h"
87
#include <pybind11/pybind11.h>
98
#include <cstdint>
109
#include <sstream>
@@ -85,33 +84,82 @@ VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
8584
return decoder;
8685
}
8786

87+
// The elements of this tuple are all tensors that represent a single frame:
88+
// 1. The frame data, which is a multidimensional tensor.
89+
// 2. A single float value for the pts in seconds.
90+
// 3. A single float value for the duration in seconds.
91+
// The reason we use Tensors for the second and third values is so we can run
92+
// under torch.compile().
93+
using OpsFrameOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
94+
8895
OpsFrameOutput makeOpsFrameOutput(VideoDecoder::FrameOutput& frame) {
8996
return std::make_tuple(
9097
frame.data,
9198
torch::tensor(frame.ptsSeconds, torch::dtype(torch::kFloat64)),
9299
torch::tensor(frame.durationSeconds, torch::dtype(torch::kFloat64)));
93100
}
94101

102+
// All elements of this tuple are tensors of the same leading dimension. The
103+
// tuple represents the frames for N total frames, where N is the dimension of
104+
// each stacked tensor. The elments are:
105+
// 1. Stacked tensor of data for all N frames. Each frame is also a
106+
// multidimensional tensor.
107+
// 2. Tensor of N pts values in seconds, where each pts is a single
108+
// float.
109+
// 3. Tensor of N durationis in seconds, where each duration is a
110+
// single float.
111+
using OpsFrameBatchOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
112+
95113
OpsFrameBatchOutput makeOpsFrameBatchOutput(
96114
VideoDecoder::FrameBatchOutput& batch) {
97115
return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds);
98116
}
99117

118+
// The elements of this tuple are all tensors that represent the concatenation
119+
// of multiple audio frames:
120+
// 1. The frames data (concatenated)
121+
// 2. A single float value for the pts of the first frame, in seconds.
122+
using OpsAudioFramesOutput = std::tuple<at::Tensor, at::Tensor>;
123+
100124
OpsAudioFramesOutput makeOpsAudioFramesOutput(
101125
VideoDecoder::AudioFramesOutput& audioFrames) {
102126
return std::make_tuple(
103127
audioFrames.data,
104128
torch::tensor(audioFrames.ptsSeconds, torch::dtype(torch::kFloat64)));
105129
}
130+
131+
// Wrap `value` in double quotes so it can be embedded in a JSON document.
// Escapes embedded double quotes and backslashes so that values containing
// those characters still yield valid JSON (RFC 8259 string rules); values
// without them are returned exactly as before ("\"" + value + "\"").
std::string quoteValue(const std::string& value) {
  std::string quoted;
  quoted.reserve(value.size() + 2);
  quoted += '"';
  for (char c : value) {
    if (c == '"' || c == '\\') {
      quoted += '\\';
    }
    quoted += c;
  }
  quoted += '"';
  return quoted;
}
134+
135+
// Serialize a string->string map as a pretty-printed JSON object. Keys are
// quoted; values are emitted verbatim, so callers must pass values that are
// already valid JSON fragments (numbers, or strings pre-quoted upstream).
std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
  std::stringstream ss;
  ss << "{\n";
  bool isFirstEntry = true;
  for (const auto& [key, value] : metadataMap) {
    if (!isFirstEntry) {
      ss << ",\n";
    }
    ss << "\"" << key << "\": " << value;
    isFirstEntry = false;
  }
  // The original closes the last entry (if any) with a bare newline before
  // the closing brace; an empty map renders as "{\n}".
  if (!metadataMap.empty()) {
    ss << "\n";
  }
  ss << "}";
  return ss.str();
}
152+
106153
} // namespace
107154

108155
// ==============================
109156
// Implementations for the operators
110157
// ==============================
111158

159+
// Create a VideoDecoder from file and wrap the pointer in a tensor.
112160
at::Tensor create_from_file(
113161
std::string_view filename,
114-
std::optional<std::string_view> seek_mode) {
162+
std::optional<std::string_view> seek_mode = std::nullopt) {
115163
std::string filenameStr(filename);
116164

117165
VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact;
@@ -125,9 +173,11 @@ at::Tensor create_from_file(
125173
return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
126174
}
127175

176+
// Create a VideoDecoder from the actual bytes of a video and wrap the pointer
177+
// in a tensor. The VideoDecoder will decode the provided bytes.
128178
at::Tensor create_from_tensor(
129179
at::Tensor video_tensor,
130-
std::optional<std::string_view> seek_mode) {
180+
std::optional<std::string_view> seek_mode = std::nullopt) {
131181
TORCH_CHECK(video_tensor.is_contiguous(), "video_tensor must be contiguous");
132182
TORCH_CHECK(
133183
video_tensor.scalar_type() == torch::kUInt8,
@@ -153,33 +203,15 @@ at::Tensor _convert_to_tensor(int64_t decoder_ptr) {
153203
return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
154204
}
155205

156-
void add_video_stream(
157-
at::Tensor& decoder,
158-
std::optional<int64_t> width,
159-
std::optional<int64_t> height,
160-
std::optional<int64_t> num_threads,
161-
std::optional<std::string_view> dimension_order,
162-
std::optional<int64_t> stream_index,
163-
std::optional<std::string_view> device) {
164-
_add_video_stream(
165-
decoder,
166-
width,
167-
height,
168-
num_threads,
169-
dimension_order,
170-
stream_index,
171-
device);
172-
}
173-
174206
void _add_video_stream(
175207
at::Tensor& decoder,
176-
std::optional<int64_t> width,
177-
std::optional<int64_t> height,
178-
std::optional<int64_t> num_threads,
179-
std::optional<std::string_view> dimension_order,
180-
std::optional<int64_t> stream_index,
181-
std::optional<std::string_view> device,
182-
std::optional<std::string_view> color_conversion_library) {
208+
std::optional<int64_t> width = std::nullopt,
209+
std::optional<int64_t> height = std::nullopt,
210+
std::optional<int64_t> num_threads = std::nullopt,
211+
std::optional<std::string_view> dimension_order = std::nullopt,
212+
std::optional<int64_t> stream_index = std::nullopt,
213+
std::optional<std::string_view> device = std::nullopt,
214+
std::optional<std::string_view> color_conversion_library = std::nullopt) {
183215
VideoDecoder::VideoStreamOptions videoStreamOptions;
184216
videoStreamOptions.width = width;
185217
videoStreamOptions.height = height;
@@ -221,22 +253,44 @@ void _add_video_stream(
221253
videoDecoder->addVideoStream(stream_index.value_or(-1), videoStreamOptions);
222254
}
223255

256+
// Add a new video stream at `stream_index` using the provided options.
257+
void add_video_stream(
258+
at::Tensor& decoder,
259+
std::optional<int64_t> width = std::nullopt,
260+
std::optional<int64_t> height = std::nullopt,
261+
std::optional<int64_t> num_threads = std::nullopt,
262+
std::optional<std::string_view> dimension_order = std::nullopt,
263+
std::optional<int64_t> stream_index = std::nullopt,
264+
std::optional<std::string_view> device = std::nullopt) {
265+
_add_video_stream(
266+
decoder,
267+
width,
268+
height,
269+
num_threads,
270+
dimension_order,
271+
stream_index,
272+
device);
273+
}
274+
224275
void add_audio_stream(
225276
at::Tensor& decoder,
226-
std::optional<int64_t> stream_index,
227-
std::optional<int64_t> sample_rate) {
277+
std::optional<int64_t> stream_index = std::nullopt,
278+
std::optional<int64_t> sample_rate = std::nullopt) {
228279
VideoDecoder::AudioStreamOptions audioStreamOptions;
229280
audioStreamOptions.sampleRate = sample_rate;
230281

231282
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
232283
videoDecoder->addAudioStream(stream_index.value_or(-1), audioStreamOptions);
233284
}
234285

286+
// Seek to a particular presentation timestamp in the video in seconds.
235287
void seek_to_pts(at::Tensor& decoder, double seconds) {
236288
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
237289
videoDecoder->setCursorPtsInSeconds(seconds);
238290
}
239291

292+
// Get the next frame from the video as a tuple that has the frame data, pts and
293+
// duration as tensors.
240294
OpsFrameOutput get_next_frame(at::Tensor& decoder) {
241295
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
242296
VideoDecoder::FrameOutput result;
@@ -248,6 +302,9 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) {
248302
return makeOpsFrameOutput(result);
249303
}
250304

305+
// Return the frame that is visible at a given timestamp in seconds. Each frame
306+
// in FFMPEG has a presentation timestamp and a duration. The frame visible at a
307+
// given timestamp T has T >= PTS and T < PTS + Duration.
251308
OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
252309
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
253310
VideoDecoder::FrameOutput result;
@@ -259,12 +316,14 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
259316
return makeOpsFrameOutput(result);
260317
}
261318

319+
// Return the frame that is visible at a given index in the video.
262320
OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
263321
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
264322
auto result = videoDecoder->getFrameAtIndex(frame_index);
265323
return makeOpsFrameOutput(result);
266324
}
267325

326+
// Return the frames at given indices for a given stream
268327
OpsFrameBatchOutput get_frames_at_indices(
269328
at::Tensor& decoder,
270329
at::IntArrayRef frame_indices) {
@@ -275,16 +334,19 @@ OpsFrameBatchOutput get_frames_at_indices(
275334
return makeOpsFrameBatchOutput(result);
276335
}
277336

337+
// Return the frames inside a range as a single stacked Tensor. The range is
338+
// defined as [start, stop).
278339
OpsFrameBatchOutput get_frames_in_range(
279340
at::Tensor& decoder,
280341
int64_t start,
281342
int64_t stop,
282-
std::optional<int64_t> step) {
343+
std::optional<int64_t> step = std::nullopt) {
283344
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
284345
auto result = videoDecoder->getFramesInRange(start, stop, step.value_or(1));
285346
return makeOpsFrameBatchOutput(result);
286347
}
287348

349+
// Return the frames at the given pts values for a given stream
288350
OpsFrameBatchOutput get_frames_by_pts(
289351
at::Tensor& decoder,
290352
at::ArrayRef<double> timestamps) {
@@ -294,6 +356,9 @@ OpsFrameBatchOutput get_frames_by_pts(
294356
return makeOpsFrameBatchOutput(result);
295357
}
296358

359+
// Return the frames inside the range as a single stacked Tensor. The range is
360+
// defined as [start_seconds, stop_seconds). The frames are stacked in pts
361+
// order.
297362
OpsFrameBatchOutput get_frames_by_pts_in_range(
298363
at::Tensor& decoder,
299364
double start_seconds,
@@ -307,35 +372,22 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
307372
// Return the audio frames played in [start_seconds, stop_seconds) concatenated
// into a single tensor, together with the pts of the first returned frame.
// An absent stop_seconds is forwarded to the decoder unchanged (i.e. "until
// the end" semantics are decided by getFramesPlayedInRangeAudio).
OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
    at::Tensor& decoder,
    double start_seconds,
    std::optional<double> stop_seconds = std::nullopt) {
  auto* videoDecoder = unwrapTensorToGetDecoder(decoder);
  auto audioFrames =
      videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
  return makeOpsAudioFramesOutput(audioFrames);
}
316381

317-
std::string quoteValue(const std::string& value) {
318-
return "\"" + value + "\"";
319-
}
320-
321-
std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
322-
std::stringstream ss;
323-
ss << "{\n";
324-
auto it = metadataMap.begin();
325-
while (it != metadataMap.end()) {
326-
ss << "\"" << it->first << "\": " << it->second;
327-
++it;
328-
if (it != metadataMap.end()) {
329-
ss << ",\n";
330-
} else {
331-
ss << "\n";
332-
}
333-
}
334-
ss << "}";
335-
336-
return ss.str();
337-
}
338-
382+
// For testing only. We need to implement this operation as a core library
383+
// function because what we're testing is round-tripping pts values as
384+
// double-precision floating point numbers from C++ to Python and back to C++.
385+
// We want to make sure that the value is preserved exactly, bit-for-bit, during
386+
// this process.
387+
//
388+
// Returns true if for the given decoder, the pts
389+
// value when converted to seconds as a double is exactly pts_seconds_to_test.
390+
// Returns false otherwise.
339391
bool _test_frame_pts_equality(
340392
at::Tensor& decoder,
341393
int64_t frame_index,
@@ -350,6 +402,7 @@ torch::Tensor _get_key_frame_indices(at::Tensor& decoder) {
350402
return videoDecoder->getKeyFrameIndices();
351403
}
352404

405+
// Get the metadata from the video as a string.
353406
std::string get_json_metadata(at::Tensor& decoder) {
354407
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
355408

@@ -418,6 +471,7 @@ std::string get_json_metadata(at::Tensor& decoder) {
418471
return mapToJson(metadataMap);
419472
}
420473

474+
// Get the container metadata as a string.
421475
std::string get_container_json_metadata(at::Tensor& decoder) {
422476
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
423477

@@ -448,6 +502,7 @@ std::string get_container_json_metadata(at::Tensor& decoder) {
448502
return mapToJson(map);
449503
}
450504

505+
// Get the stream metadata as a string.
451506
std::string get_stream_json_metadata(
452507
at::Tensor& decoder,
453508
int64_t stream_index) {
@@ -519,6 +574,8 @@ std::string get_stream_json_metadata(
519574
return mapToJson(map);
520575
}
521576

577+
// Returns version information about the various FFMPEG libraries that are
578+
// loaded in the program's address space.
522579
std::string _get_json_ffmpeg_library_versions() {
523580
std::stringstream ss;
524581
ss << "{\n";
@@ -545,6 +602,10 @@ std::string _get_json_ffmpeg_library_versions() {
545602
return ss.str();
546603
}
547604

605+
// Scans video packets to get more accurate metadata like frame count, exact
606+
// keyframe positions, etc. Exact keyframe positions are useful for efficient
607+
// accurate seeking. Note that this function reads the entire video but it does
608+
// not decode frames. Reading a video file is much cheaper than decoding it.
548609
void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
549610
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
550611
videoDecoder->scanFileAndUpdateMetadataAndIndex();

0 commit comments

Comments
 (0)