Skip to content

Commit 6601d50

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into readme_decord
2 parents 5937eac + 39463b8 commit 6601d50

File tree

5 files changed

+135
-45
lines changed

5 files changed

+135
-45
lines changed

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace facebook::torchcodec {
1717
void convertAVFrameToDecodedOutputOnCuda(
1818
const torch::Device& device,
1919
const VideoDecoder::VideoStreamDecoderOptions& options,
20-
AVCodecContext* codecContext,
20+
const VideoDecoder::StreamMetadata& metadata,
2121
VideoDecoder::RawDecodedOutput& rawOutput,
2222
VideoDecoder::DecodedOutput& output,
2323
std::optional<torch::Tensor> preAllocatedOutputTensor) {

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -154,18 +154,6 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
154154
#endif
155155
}
156156

157-
torch::Tensor allocateDeviceTensor(
158-
at::IntArrayRef shape,
159-
torch::Device device,
160-
const torch::Dtype dtype = torch::kUInt8) {
161-
return torch::empty(
162-
shape,
163-
torch::TensorOptions()
164-
.dtype(dtype)
165-
.layout(torch::kStrided)
166-
.device(device));
167-
}
168-
169157
void throwErrorIfNonCudaDevice(const torch::Device& device) {
170158
TORCH_CHECK(
171159
device.type() != torch::kCPU,
@@ -199,7 +187,7 @@ void initializeContextOnCuda(
199187
void convertAVFrameToDecodedOutputOnCuda(
200188
const torch::Device& device,
201189
const VideoDecoder::VideoStreamDecoderOptions& options,
202-
AVCodecContext* codecContext,
190+
const VideoDecoder::StreamMetadata& metadata,
203191
VideoDecoder::RawDecodedOutput& rawOutput,
204192
VideoDecoder::DecodedOutput& output,
205193
std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -209,8 +197,9 @@ void convertAVFrameToDecodedOutputOnCuda(
209197
src->format == AV_PIX_FMT_CUDA,
210198
"Expected format to be AV_PIX_FMT_CUDA, got " +
211199
std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
212-
int width = options.width.value_or(codecContext->width);
213-
int height = options.height.value_or(codecContext->height);
200+
auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, metadata);
201+
int height = frameDims.height;
202+
int width = frameDims.width;
214203
NppiSize oSizeROI = {width, height};
215204
Npp8u* input[2] = {src->data[0], src->data[1]};
216205
torch::Tensor& dst = output.frame;
@@ -227,7 +216,7 @@ void convertAVFrameToDecodedOutputOnCuda(
227216
"x3, got ",
228217
shape);
229218
} else {
230-
dst = allocateDeviceTensor({height, width, 3}, options.device);
219+
dst = allocateEmptyHWCTensor(height, width, options.device);
231220
}
232221

233222
// Use the user-requested GPU for running the NPP kernel.

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void initializeContextOnCuda(
3535
void convertAVFrameToDecodedOutputOnCuda(
3636
const torch::Device& device,
3737
const VideoDecoder::VideoStreamDecoderOptions& options,
38-
AVCodecContext* codecContext,
38+
const VideoDecoder::StreamMetadata& metadata,
3939
VideoDecoder::RawDecodedOutput& rawOutput,
4040
VideoDecoder::DecodedOutput& output,
4141
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 64 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
195195
int64_t numFrames,
196196
const VideoStreamDecoderOptions& options,
197197
const StreamMetadata& metadata)
198-
: frames(torch::empty(
199-
{numFrames,
200-
options.height.value_or(*metadata.height),
201-
options.width.value_or(*metadata.width),
202-
3},
203-
at::TensorOptions(options.device).dtype(torch::kUInt8))),
204-
ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
205-
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}
198+
: ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
199+
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
200+
auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, metadata);
201+
int height = frameDims.height;
202+
int width = frameDims.width;
203+
frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
204+
}
206205

207206
VideoDecoder::VideoDecoder() {}
208207

@@ -364,12 +363,11 @@ void VideoDecoder::initializeFilterGraphForStream(
364363
inputs->pad_idx = 0;
365364
inputs->next = nullptr;
366365
char description[512];
367-
int width = activeStream.codecContext->width;
368-
int height = activeStream.codecContext->height;
369-
if (options.height.has_value() && options.width.has_value()) {
370-
width = *options.width;
371-
height = *options.height;
372-
}
366+
auto frameDims = getHeightAndWidthFromOptionsOrMetadata(
367+
options, containerMetadata_.streams[streamIndex]);
368+
int height = frameDims.height;
369+
int width = frameDims.width;
370+
373371
std::snprintf(
374372
description,
375373
sizeof(description),
@@ -835,10 +833,6 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
835833
StreamInfo& activeStream = streams_[frameStreamIndex];
836834
activeStream.currentPts = frame->pts;
837835
activeStream.currentDuration = getDuration(frame);
838-
auto startToSeekDone =
839-
std::chrono::duration_cast<std::chrono::milliseconds>(seekDone - start);
840-
auto seekToDecodeDone = std::chrono::duration_cast<std::chrono::milliseconds>(
841-
decodeDone - seekDone);
842836
RawDecodedOutput rawOutput;
843837
rawOutput.streamIndex = frameStreamIndex;
844838
rawOutput.frame = std::move(frame);
@@ -869,7 +863,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
869863
convertAVFrameToDecodedOutputOnCuda(
870864
streamInfo.options.device,
871865
streamInfo.options,
872-
streamInfo.codecContext.get(),
866+
containerMetadata_.streams[streamIndex],
873867
rawOutput,
874868
output,
875869
preAllocatedOutputTensor);
@@ -899,8 +893,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
899893
torch::Tensor tensor;
900894
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
901895
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
902-
int width = streamInfo.options.width.value_or(frame->width);
903-
int height = streamInfo.options.height.value_or(frame->height);
896+
auto frameDims =
897+
getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
898+
int height = frameDims.height;
899+
int width = frameDims.width;
904900
if (preAllocatedOutputTensor.has_value()) {
905901
tensor = preAllocatedOutputTensor.value();
906902
auto shape = tensor.sizes();
@@ -914,8 +910,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
914910
"x3, got ",
915911
shape);
916912
} else {
917-
tensor = torch::empty(
918-
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
913+
tensor = allocateEmptyHWCTensor(height, width, torch::kCPU);
919914
}
920915
rawOutput.data = tensor.data_ptr<uint8_t>();
921916
convertFrameToBufferUsingSwsScale(rawOutput);
@@ -1315,8 +1310,10 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
13151310
enum AVPixelFormat frameFormat =
13161311
static_cast<enum AVPixelFormat>(frame->format);
13171312
StreamInfo& activeStream = streams_[streamIndex];
1318-
int outputWidth = activeStream.options.width.value_or(frame->width);
1319-
int outputHeight = activeStream.options.height.value_or(frame->height);
1313+
auto frameDims =
1314+
getHeightAndWidthFromOptionsOrAVFrame(activeStream.options, *frame);
1315+
int outputHeight = frameDims.height;
1316+
int outputWidth = frameDims.width;
13201317
if (activeStream.swsContext.get() == nullptr) {
13211318
SwsContext* swsContext = sws_getContext(
13221319
frame->width,
@@ -1382,15 +1379,18 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
13821379
ffmpegStatus =
13831380
av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
13841381
TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
1385-
std::vector<int64_t> shape = {filteredFrame->height, filteredFrame->width, 3};
1382+
auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
1383+
streams_[streamIndex].options, *filteredFrame.get());
1384+
int height = frameDims.height;
1385+
int width = frameDims.width;
1386+
std::vector<int64_t> shape = {height, width, 3};
13861387
std::vector<int64_t> strides = {filteredFrame->linesize[0], 3, 1};
13871388
AVFrame* filteredFramePtr = filteredFrame.release();
13881389
auto deleter = [filteredFramePtr](void*) {
13891390
UniqueAVFrame frameToDelete(filteredFramePtr);
13901391
};
13911392
torch::Tensor tensor = torch::from_blob(
13921393
filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
1393-
StreamInfo& activeStream = streams_[streamIndex];
13941394
return tensor;
13951395
}
13961396

@@ -1406,6 +1406,43 @@ VideoDecoder::~VideoDecoder() {
14061406
}
14071407
}
14081408

1409+
FrameDims getHeightAndWidthFromOptionsOrMetadata(
1410+
const VideoDecoder::VideoStreamDecoderOptions& options,
1411+
const VideoDecoder::StreamMetadata& metadata) {
1412+
return FrameDims(
1413+
options.height.value_or(*metadata.height),
1414+
options.width.value_or(*metadata.width));
1415+
}
1416+
1417+
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
1418+
const VideoDecoder::VideoStreamDecoderOptions& options,
1419+
const AVFrame& avFrame) {
1420+
return FrameDims(
1421+
options.height.value_or(avFrame.height),
1422+
options.width.value_or(avFrame.width));
1423+
}
1424+
1425+
torch::Tensor allocateEmptyHWCTensor(
1426+
int height,
1427+
int width,
1428+
torch::Device device,
1429+
std::optional<int> numFrames) {
1430+
auto tensorOptions = torch::TensorOptions()
1431+
.dtype(torch::kUInt8)
1432+
.layout(torch::kStrided)
1433+
.device(device);
1434+
TORCH_CHECK(height > 0, "height must be > 0, got: ", height);
1435+
TORCH_CHECK(width > 0, "width must be > 0, got: ", width);
1436+
if (numFrames.has_value()) {
1437+
auto numFramesValue = numFrames.value();
1438+
TORCH_CHECK(
1439+
numFramesValue >= 0, "numFrames must be >= 0, got: ", numFramesValue);
1440+
return torch::empty({numFramesValue, height, width, 3}, tensorOptions);
1441+
} else {
1442+
return torch::empty({height, width, 3}, tensorOptions);
1443+
}
1444+
}
1445+
14091446
std::ostream& operator<<(
14101447
std::ostream& os,
14111448
const VideoDecoder::DecodeStats& stats) {

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ class VideoDecoder {
243243
const VideoStreamDecoderOptions& options,
244244
const StreamMetadata& metadata);
245245
};
246+
246247
// Returns frames at the given indices for a given stream as a single stacked
247248
// Tensor.
248249
BatchDecodedOutput getFramesAtIndices(
@@ -413,6 +414,69 @@ class VideoDecoder {
413414
bool scanned_all_streams_ = false;
414415
};
415416

417+
// --------------------------------------------------------------------------
418+
// FRAME TENSOR ALLOCATION APIs
419+
// --------------------------------------------------------------------------
420+
421+
// Note [Frame Tensor allocation and height and width]
422+
//
423+
// We always allocate [N]HWC tensors. The low-level decoding functions all
424+
// assume HWC tensors, since this is what FFmpeg natively handles. It's up to
425+
// the high-level decoding entry-points to permute that back to CHW, by calling
426+
// MaybePermuteHWC2CHW().
427+
//
428+
// Also, importantly, the way we figure out the height and width of the
429+
// output frame varies and depends on the decoding entry-point:
430+
// - In all cases, if the user requested specific height and width from the
431+
// options, we honor that. Otherwise we fall into one of the categories below.
432+
// - In Batch decoding APIs (e.g. getFramesAtIndices), we get height and width
433+
// from the stream metadata, which itself got its value from the CodecContext,
434+
// when the stream was added.
435+
// - In single frames APIs:
436+
// - On CPU we get height and width from the AVFrame.
437+
// - On GPU, we get height and width from the metadata (same as batch APIs)
438+
//
439+
// These 2 strategies are encapsulated within
440+
// getHeightAndWidthFromOptionsOrMetadata() and
441+
// getHeightAndWidthFromOptionsOrAVFrame(). The reason they exist is to make it
442+
// very obvious which logic is used in which place, and they allow for `git
443+
// grep`ing.
444+
//
445+
// The source of truth for height and width really is the AVFrame: it's the
446+
// decoded ouptut from FFmpeg. The info from the metadata (i.e. from the
447+
// CodecContext) may not be as accurate. However, the AVFrame is only available
448+
// late in the call stack, when the frame is decoded, while the CodecContext is
449+
// available early when a stream is added. This is why we use the CodecContext
450+
// for pre-allocating batched output tensors (we could pre-allocate those only
451+
// once we decode the first frame to get the info from the AVFrame, but that's
452+
// a more complex logic).
453+
//
454+
// Because the sources for height and width may disagree, we may end up with
455+
// conflicts: e.g. if we pre-allocate a batch output tensor based on the
456+
// metadata info, but the decoded AVFrame has a different height and width.
457+
// It is very important to check the height and width assumptions where the
458+
// tensors' memory is used/filled in order to avoid segfaults.
459+
460+
struct FrameDims {
461+
int height;
462+
int width;
463+
FrameDims(int h, int w) : height(h), width(w) {}
464+
};
465+
466+
FrameDims getHeightAndWidthFromOptionsOrMetadata(
467+
const VideoDecoder::VideoStreamDecoderOptions& options,
468+
const VideoDecoder::StreamMetadata& metadata);
469+
470+
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
471+
const VideoDecoder::VideoStreamDecoderOptions& options,
472+
const AVFrame& avFrame);
473+
474+
torch::Tensor allocateEmptyHWCTensor(
475+
int height,
476+
int width,
477+
torch::Device device,
478+
std::optional<int> numFrames = std::nullopt);
479+
416480
// Prints the VideoDecoder::DecodeStats to the ostream.
417481
std::ostream& operator<<(
418482
std::ostream& os,

0 commit comments

Comments
 (0)