Merged
Changes from 6 commits
50 changes: 35 additions & 15 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -846,7 +846,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
}

VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
VideoDecoder::RawDecodedOutput& rawOutput) {
VideoDecoder::RawDecodedOutput& rawOutput,
torch::Tensor& preAllocatedOutputTensor) {
// Convert the frame to tensor.
DecodedOutput output;
int streamIndex = rawOutput.streamIndex;
@@ -861,8 +862,10 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
output.durationSeconds = ptsToSeconds(
getDuration(frame), formatContext_->streams[streamIndex]->time_base);
if (streamInfo.options.device.type() == torch::kCPU) {
convertAVFrameToDecodedOutputOnCPU(rawOutput, output);
convertAVFrameToDecodedOutputOnCPU(
rawOutput, output, preAllocatedOutputTensor);
} else if (streamInfo.options.device.type() == torch::kCUDA) {
// TODO: handle pre-allocated output tensor
Contributor Author:

This PR is a no-op for CUDA devices. I'm leaving out CUDA pre-allocation because it is strongly tied to #189 and can be treated separately.

convertAVFrameToDecodedOutputOnCuda(
streamInfo.options.device,
streamInfo.options,
@@ -878,16 +881,24 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(

void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
VideoDecoder::RawDecodedOutput& rawOutput,
DecodedOutput& output) {
DecodedOutput& output,
torch::Tensor& preAllocatedOutputTensor) {
int streamIndex = rawOutput.streamIndex;
AVFrame* frame = rawOutput.frame.get();
auto& streamInfo = streams_[streamIndex];
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
int width = streamInfo.options.width.value_or(frame->width);
int height = streamInfo.options.height.value_or(frame->height);
torch::Tensor tensor = torch::empty(
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
torch::Tensor tensor;
if (preAllocatedOutputTensor.numel() != 0) {
// TODO: check shape of preAllocatedOutputTensor?
Contributor:

I think we should TORCH_CHECK for height, width, shape, etc. here.

Contributor Author:

I gave this a try, thinking it would be a simple assert like

assert `shape[-3:] == (H, W, 3)`

But it turns out it's not as simple. Some tensors come as HWC while others come as CHW. This is because the pre-allocated batched tensors are allocated as such:

https://github.com/pytorch/torchcodec/blob/c6a0a5a079da408df56fd7c06e3b801cbada4db1/src/torchcodec/decoders/_core/VideoDecoder.cpp#L171-L186

It made me realize that everything works, but it's pretty magical. We end up doing the .permute() calls in different places, and I think it would be a lot cleaner if we allocated batched output only as NHWC, and then permuted the entire NHWC tensor in one go. What we do right now is permute all of the N HWC tensors, which is probably not as efficient (or as clean).

I want to fix this as an immediate follow-up if that's OK. I gave it a try here, but it's not trivial and it might be preferable not to overcomplicate this PR.

Contributor:

I'm in favor of @NicolasHug's suggestion. The logic he points out is legacy from way back, and it wasn't necessarily thought through in terms of long-term maintenance and code health. Always doing it one way, and then permuting as needed on the way out, sounds easier and cleaner.

Contributor:

Sounds good to me
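The single-permute idea in the thread above can be sketched with NumPy standing in for torch (shapes and names are purely illustrative): permuting one pre-allocated NHWC batch once yields the same NCHW result as permuting each HWC frame individually.

```python
import numpy as np

N, H, W, C = 4, 2, 3, 3  # tiny batch, purely for illustration

rng = np.random.default_rng(0)
frames_hwc = [rng.integers(0, 256, (H, W, C), dtype=np.uint8) for _ in range(N)]

# Current behavior: each decoded HWC frame is permuted to CHW on its own.
per_frame = np.stack([f.transpose(2, 0, 1) for f in frames_hwc])  # (N, C, H, W)

# Suggested behavior: pre-allocate one NHWC buffer, decode into its HWC
# slices, then permute the whole batch once on the way out.
batch_nhwc = np.empty((N, H, W, C), dtype=np.uint8)
for i, f in enumerate(frames_hwc):
    batch_nhwc[i] = f  # stands in for decoding directly into the slice
batch_nchw = batch_nhwc.transpose(0, 3, 1, 2)  # a view; no data copy

assert np.array_equal(per_frame, batch_nchw)
```

The batched transpose is a single metadata-only view, which is the cleanliness (and likely efficiency) win the comment is pointing at.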

tensor = preAllocatedOutputTensor;
} else {
int width = streamInfo.options.width.value_or(frame->width);
int height = streamInfo.options.height.value_or(frame->height);
tensor = torch::empty(
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
}

rawOutput.data = tensor.data_ptr<uint8_t>();
convertFrameToBufferUsingSwsScale(rawOutput);

@@ -945,7 +956,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux(
return seconds >= frameStartTime && seconds < frameEndTime;
});
// Convert the frame to tensor.
return convertAVFrameToDecodedOutput(rawOutput);
auto preAllocatedOutputTensor = torch::empty({0});
return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
}

void VideoDecoder::validateUserProvidedStreamIndex(uint64_t streamIndex) {
@@ -980,7 +992,8 @@ void VideoDecoder::validateFrameIndex(

VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
int streamIndex,
int64_t frameIndex) {
int64_t frameIndex,
torch::Tensor& preAllocatedOutputTensor) {
validateUserProvidedStreamIndex(streamIndex);
validateScannedAllStreams("getFrameAtIndex");

@@ -989,7 +1002,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(

int64_t pts = stream.allFrames[frameIndex].pts;
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
return getNextDecodedOutputNoDemux();
return getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
}

VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
@@ -1061,8 +1074,9 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
BatchDecodedOutput output(numOutputFrames, options, streamMetadata);

for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
output.frames[f] = singleOut.frame;
auto preAllocatedOutputTensor = output.frames[f];
DecodedOutput singleOut =
getFrameAtIndex(streamIndex, i, preAllocatedOutputTensor);
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
}
@@ -1154,8 +1168,9 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
int64_t numFrames = stopFrameIndex - startFrameIndex;
BatchDecodedOutput output(numFrames, options, streamMetadata);
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
output.frames[f] = singleOut.frame;
auto preAllocatedOutputTensor = output.frames[f];
DecodedOutput singleOut =
getFrameAtIndex(streamIndex, i, preAllocatedOutputTensor);
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
}
@@ -1174,8 +1189,13 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
}

VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
auto preAllocatedOutputTensor = torch::empty({0});
return VideoDecoder::getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
}
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux(
torch::Tensor& preAllocatedOutputTensor) {
auto rawOutput = getNextRawDecodedOutputNoDemux();
return convertAVFrameToDecodedOutput(rawOutput);
return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
}
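The parameterless overload above delegates by passing an empty tensor as a sentinel meaning "allocate the output internally", which matches the `numel() != 0` check in `convertAVFrameToDecodedOutputOnCPU`. A rough Python/NumPy sketch of that convention (NumPy standing in for torch; `convert_frame` is a made-up name):

```python
import numpy as np

def convert_frame(frame_hwc, pre_allocated):
    # Empty sentinel (like torch::empty({0})): allocate the output here.
    # Non-empty: write into the caller's buffer, e.g. a slice of a batch.
    if pre_allocated.size == 0:
        out = np.empty(frame_hwc.shape, dtype=np.uint8)
    else:
        out = pre_allocated
    out[...] = frame_hwc
    return out

frame = np.full((2, 3, 3), 7, dtype=np.uint8)   # one HWC frame
batch = np.zeros((4, 2, 3, 3), dtype=np.uint8)  # NHWC batch buffer

convert_frame(frame, batch[0])                # fills the first batch slot in place
fresh = convert_frame(frame, np.empty((0,)))  # internal-allocation path
```

The batched decode paths in this diff hand each `output.frames[f]` slice in as `pre_allocated`, so the decoder writes frames straight into the batch with no extra copy.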

void VideoDecoder::setCursorPtsInSeconds(double seconds) {
14 changes: 11 additions & 3 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -215,14 +215,19 @@ class VideoDecoder {
// Decodes the frame where the current cursor position is. It also advances
// the cursor to the next frame.
DecodedOutput getNextDecodedOutputNoDemux();
DecodedOutput getNextDecodedOutputNoDemux(
torch::Tensor& preAllocatedOutputTensor);
Contributor Author:

Note: I had to overload this one (only) because it's called in a ton of places in the C++ tests. Forcing all call sites to pass an empty tensor would be quite noisy.

// Decodes the first frame in any added stream that is visible at a given
// timestamp. Frames in the video have a presentation timestamp and a
// duration. For example, if a frame has presentation timestamp of 5.0s and a
// duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
// i.e. it will be returned when this function is called with seconds=5.0 or
// seconds=5.999, etc.
DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);
DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
DecodedOutput getFrameAtIndex(
int streamIndex,
int64_t frameIndex,
torch::Tensor& preAllocatedOutputTensor);
struct BatchDecodedOutput {
torch::Tensor frames;
torch::Tensor ptsSeconds;
@@ -363,10 +368,13 @@ class VideoDecoder {
int streamIndex,
const AVFrame* frame);
void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput);
DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput);
DecodedOutput convertAVFrameToDecodedOutput(
RawDecodedOutput& rawOutput,
torch::Tensor& preAllocatedOutputTensor);
void convertAVFrameToDecodedOutputOnCPU(
RawDecodedOutput& rawOutput,
DecodedOutput& output);
DecodedOutput& output,
torch::Tensor& preAllocatedOutputTensor);

DecoderOptions options_;
ContainerMetadata containerMetadata_;
4 changes: 3 additions & 1 deletion src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -214,7 +214,9 @@ OpsDecodedOutput get_frame_at_index(
int64_t stream_index,
int64_t frame_index) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index);
auto preAllocatedOutputTensor = torch::empty({0});
auto result = videoDecoder->getFrameAtIndex(
stream_index, frame_index, preAllocatedOutputTensor);
return makeOpsDecodedOutput(result);
}
