Skip to content

Commit c1a7f68

Browse files
committed
Merge remote-tracking branch 'origin/main' into cuda8
2 parents 425336b + c91e33e commit c1a7f68

File tree

5 files changed

+163
-79
lines changed

5 files changed

+163
-79
lines changed

.github/workflows/cuda_tests.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ jobs:
1616
python_version: ["3.9"]
1717
# TODO: Add more cuda versions.
1818
cuda_arch_version: ["12.4"]
19-
ffmpeg_version: ["origin/release/6.1"]
19+
# TODO: Get ffmpeg 4 to work. Currently fails to build with nvcc.
20+
ffmpeg_version: ["origin/release/5.1", "origin/release/6.1", "origin/release/7.1"]
2021
fail-fast: false
2122
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
2223
with:

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ AVBufferRef* getFromCache(const torch::Device& device) {
7777
return nullptr;
7878
}
7979

80+
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
81+
8082
AVBufferRef* getFFMPEGContextFromExistingCudaContext(
8183
const torch::Device& device,
8284
torch::DeviceIndex nonNegativeDeviceIndex,
@@ -105,6 +107,8 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext(
105107
return hw_device_ctx;
106108
}
107109

110+
#else
111+
108112
AVBufferRef* getFFMPEGContextFromNewCudaContext(
109113
const torch::Device& device,
110114
torch::DeviceIndex nonNegativeDeviceIndex,
@@ -122,6 +126,8 @@ AVBufferRef* getFFMPEGContextFromNewCudaContext(
122126
return hw_device_ctx;
123127
}
124128

129+
#endif
130+
125131
AVBufferRef* getCudaContext(const torch::Device& device) {
126132
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
127133
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 96 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,31 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
3434
return ptsToSeconds(pts, timeBase.den);
3535
}
3636

37+
// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
38+
// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
39+
// or 4D.
40+
// Calling permute() is guaranteed to return a view as per the docs:
41+
// https://pytorch.org/docs/stable/generated/torch.permute.html
42+
torch::Tensor MaybePermuteHWC2CHW(
43+
const VideoDecoder::VideoStreamDecoderOptions& options,
44+
torch::Tensor& hwcTensor) {
45+
if (options.dimensionOrder == "NHWC") {
46+
return hwcTensor;
47+
}
48+
auto numDimensions = hwcTensor.dim();
49+
auto shape = hwcTensor.sizes();
50+
if (numDimensions == 3) {
51+
TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
52+
return hwcTensor.permute({2, 0, 1});
53+
} else if (numDimensions == 4) {
54+
TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
55+
return hwcTensor.permute({0, 3, 1, 2});
56+
} else {
57+
TORCH_CHECK(
58+
false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
59+
}
60+
}
61+
3762
struct AVInput {
3863
UniqueAVFormatContext formatContext;
3964
std::unique_ptr<AVIOBytesContext> ioBytesContext;
@@ -167,28 +192,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
167192
const VideoStreamDecoderOptions& options,
168193
const StreamMetadata& metadata)
169194
: ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
170-
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
171-
if (options.dimensionOrder == "NHWC") {
172-
frames = torch::empty(
173-
{numFrames,
174-
options.height.value_or(*metadata.height),
175-
options.width.value_or(*metadata.width),
176-
3},
177-
{torch::kUInt8});
178-
} else if (options.dimensionOrder == "NCHW") {
179-
frames = torch::empty(
180-
{numFrames,
181-
3,
182-
options.height.value_or(*metadata.height),
183-
options.width.value_or(*metadata.width)},
184-
torch::TensorOptions()
185-
.memory_format(torch::MemoryFormat::ChannelsLast)
186-
.dtype({torch::kUInt8}));
187-
} else {
188-
TORCH_CHECK(
189-
false, "Unsupported frame dimensionOrder =" + options.dimensionOrder)
190-
}
191-
}
195+
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})),
196+
frames(torch::empty(
197+
{numFrames,
198+
options.height.value_or(*metadata.height),
199+
options.width.value_or(*metadata.width),
200+
3},
201+
{torch::kUInt8})) {}
192202

193203
VideoDecoder::VideoDecoder() {}
194204

@@ -652,8 +662,9 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
652662
}
653663
for (int streamIndex : activeStreamIndices_) {
654664
StreamInfo& streamInfo = streams_[streamIndex];
655-
streamInfo.discardFramesBeforePts =
656-
*maybeDesiredPts_ * streamInfo.timeBase.den;
665+
// clang-format off: clang format clashes
666+
streamInfo.discardFramesBeforePts = *maybeDesiredPts_ * streamInfo.timeBase.den;
667+
// clang-format on
657668
}
658669

659670
decodeStats_.numSeeksAttempted++;
@@ -846,7 +857,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
846857
}
847858

848859
VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
849-
VideoDecoder::RawDecodedOutput& rawOutput) {
860+
VideoDecoder::RawDecodedOutput& rawOutput,
861+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
850862
// Convert the frame to tensor.
851863
DecodedOutput output;
852864
int streamIndex = rawOutput.streamIndex;
@@ -861,8 +873,10 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
861873
output.durationSeconds = ptsToSeconds(
862874
getDuration(frame), formatContext_->streams[streamIndex]->time_base);
863875
if (streamInfo.options.device.type() == torch::kCPU) {
864-
convertAVFrameToDecodedOutputOnCPU(rawOutput, output);
876+
convertAVFrameToDecodedOutputOnCPU(
877+
rawOutput, output, preAllocatedOutputTensor);
865878
} else if (streamInfo.options.device.type() == torch::kCUDA) {
879+
// TODO: handle pre-allocated output tensor
866880
convertAVFrameToDecodedOutputOnCuda(
867881
streamInfo.options.device,
868882
streamInfo.options,
@@ -878,22 +892,35 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
878892

879893
void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
880894
VideoDecoder::RawDecodedOutput& rawOutput,
881-
DecodedOutput& output) {
895+
DecodedOutput& output,
896+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
882897
int streamIndex = rawOutput.streamIndex;
883898
AVFrame* frame = rawOutput.frame.get();
884899
auto& streamInfo = streams_[streamIndex];
885900
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
886901
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
902+
torch::Tensor tensor;
887903
int width = streamInfo.options.width.value_or(frame->width);
888904
int height = streamInfo.options.height.value_or(frame->height);
889-
torch::Tensor tensor = torch::empty(
890-
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
905+
if (preAllocatedOutputTensor.has_value()) {
906+
tensor = preAllocatedOutputTensor.value();
907+
auto shape = tensor.sizes();
908+
TORCH_CHECK(
909+
(shape.size() == 3) && (shape[0] == height) &&
910+
(shape[1] == width) && (shape[2] == 3),
911+
"Expected tensor of shape ",
912+
height,
913+
"x",
914+
width,
915+
"x3, got ",
916+
shape);
917+
} else {
918+
tensor = torch::empty(
919+
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
920+
}
891921
rawOutput.data = tensor.data_ptr<uint8_t>();
892922
convertFrameToBufferUsingSwsScale(rawOutput);
893923

894-
if (streamInfo.options.dimensionOrder == "NCHW") {
895-
tensor = tensor.permute({2, 0, 1});
896-
}
897924
output.frame = tensor;
898925
} else if (
899926
streamInfo.colorConversionLibrary ==
@@ -904,6 +931,14 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
904931
"Invalid color conversion library: " +
905932
std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
906933
}
934+
if (!preAllocatedOutputTensor.has_value()) {
935+
// We only convert to CHW if a pre-allocated tensor wasn't passed. When a
936+
// pre-allocated tensor is passed, it's up to the caller (typically a
937+
// batch API) to do the conversion. This is more efficient as it allows
938+
// batch NHWC tensors to be permuted only once, instead of permuting HWC
939+
// tensors N times.
940+
output.frame = MaybePermuteHWC2CHW(streamInfo.options, output.frame);
941+
}
907942

908943
} else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
909944
// TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
@@ -980,7 +1015,8 @@ void VideoDecoder::validateFrameIndex(
9801015

9811016
VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
9821017
int streamIndex,
983-
int64_t frameIndex) {
1018+
int64_t frameIndex,
1019+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
9841020
validateUserProvidedStreamIndex(streamIndex);
9851021
validateScannedAllStreams("getFrameAtIndex");
9861022

@@ -989,7 +1025,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
9891025

9901026
int64_t pts = stream.allFrames[frameIndex].pts;
9911027
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
992-
return getNextDecodedOutputNoDemux();
1028+
return getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
9931029
}
9941030

9951031
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
@@ -999,40 +1035,25 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
9991035
validateScannedAllStreams("getFramesAtIndices");
10001036

10011037
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
1002-
const auto& options = streams_[streamIndex].options;
1038+
const auto& stream = streams_[streamIndex];
1039+
const auto& options = stream.options;
10031040
BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
10041041

1005-
int i = 0;
1006-
const auto& stream = streams_[streamIndex];
1007-
for (int64_t frameIndex : frameIndices) {
1042+
for (auto f = 0; f < frameIndices.size(); ++f) {
1043+
auto frameIndex = frameIndices[f];
10081044
if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) {
10091045
throw std::runtime_error(
10101046
"Invalid frame index=" + std::to_string(frameIndex));
10111047
}
1012-
int64_t pts = stream.allFrames[frameIndex].pts;
1013-
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
1014-
auto rawSingleOutput = getNextRawDecodedOutputNoDemux();
1015-
if (stream.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
1016-
// We are using sws_scale to convert the frame to tensor. sws_scale can
1017-
// convert to a pre-allocated buffer so we can do the color-conversion
1018-
// in-place on the output tensor's data_ptr.
1019-
rawSingleOutput.data = output.frames[i].data_ptr<uint8_t>();
1020-
convertFrameToBufferUsingSwsScale(rawSingleOutput);
1021-
} else if (
1022-
stream.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1023-
// We are using a filter graph to convert the frame to tensor. The
1024-
// filter graph returns us an AVFrame allocated by FFMPEG. So we need to
1025-
// copy the AVFrame to the output tensor.
1026-
torch::Tensor frame = convertFrameToTensorUsingFilterGraph(
1027-
rawSingleOutput.streamIndex, rawSingleOutput.frame.get());
1028-
output.frames[i] = frame;
1029-
} else {
1030-
throw std::runtime_error(
1031-
"Invalid color conversion library: " +
1032-
std::to_string(static_cast<int>(stream.colorConversionLibrary)));
1048+
DecodedOutput singleOut =
1049+
getFrameAtIndex(streamIndex, frameIndex, output.frames[f]);
1050+
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1051+
output.frames[f] = singleOut.frame;
10331052
}
1034-
i++;
1053+
// Note that for now we ignore the pts and duration parts of the output,
1054+
// because they're never used in any caller.
10351055
}
1056+
output.frames = MaybePermuteHWC2CHW(options, output.frames);
10361057
return output;
10371058
}
10381059

@@ -1061,12 +1082,14 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10611082
BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
10621083

10631084
for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
1064-
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
1065-
output.frames[f] = singleOut.frame;
1085+
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
1086+
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1087+
output.frames[f] = singleOut.frame;
1088+
}
10661089
output.ptsSeconds[f] = singleOut.ptsSeconds;
10671090
output.durationSeconds[f] = singleOut.durationSeconds;
10681091
}
1069-
1092+
output.frames = MaybePermuteHWC2CHW(options, output.frames);
10701093
return output;
10711094
}
10721095

@@ -1119,6 +1142,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11191142
// need this special case below.
11201143
if (startSeconds == stopSeconds) {
11211144
BatchDecodedOutput output(0, options, streamMetadata);
1145+
output.frames = MaybePermuteHWC2CHW(options, output.frames);
11221146
return output;
11231147
}
11241148

@@ -1154,11 +1178,14 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11541178
int64_t numFrames = stopFrameIndex - startFrameIndex;
11551179
BatchDecodedOutput output(numFrames, options, streamMetadata);
11561180
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
1157-
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
1158-
output.frames[f] = singleOut.frame;
1181+
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
1182+
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1183+
output.frames[f] = singleOut.frame;
1184+
}
11591185
output.ptsSeconds[f] = singleOut.ptsSeconds;
11601186
output.durationSeconds[f] = singleOut.durationSeconds;
11611187
}
1188+
output.frames = MaybePermuteHWC2CHW(options, output.frames);
11621189

11631190
return output;
11641191
}
@@ -1167,15 +1194,15 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
11671194
auto rawOutput =
11681195
getDecodedOutputWithFilter([this](int frameStreamIndex, AVFrame* frame) {
11691196
StreamInfo& activeStream = streams_[frameStreamIndex];
1170-
return frame->pts >=
1171-
activeStream.discardFramesBeforePts;
1197+
return frame->pts >= activeStream.discardFramesBeforePts;
11721198
});
11731199
return rawOutput;
11741200
}
11751201

1176-
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
1202+
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux(
1203+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
11771204
auto rawOutput = getNextRawDecodedOutputNoDemux();
1178-
return convertAVFrameToDecodedOutput(rawOutput);
1205+
return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
11791206
}
11801207

11811208
void VideoDecoder::setCursorPtsInSeconds(double seconds) {
@@ -1285,11 +1312,6 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
12851312
torch::Tensor tensor = torch::from_blob(
12861313
filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
12871314
StreamInfo& activeStream = streams_[streamIndex];
1288-
if (activeStream.options.dimensionOrder == "NCHW") {
1289-
// The docs guaranty this to return a view:
1290-
// https://pytorch.org/docs/stable/generated/torch.permute.html
1291-
tensor = tensor.permute({2, 0, 1});
1292-
}
12931315
return tensor;
12941316
}
12951317

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,15 +214,19 @@ class VideoDecoder {
214214
};
215215
// Decodes the frame where the current cursor position is. It also advances
216216
// the cursor to the next frame.
217-
DecodedOutput getNextDecodedOutputNoDemux();
217+
DecodedOutput getNextDecodedOutputNoDemux(
218+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
218219
// Decodes the first frame in any added stream that is visible at a given
219220
// timestamp. Frames in the video have a presentation timestamp and a
220221
// duration. For example, if a frame has presentation timestamp of 5.0s and a
221222
// duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
222223
// i.e. it will be returned when this function is called with seconds=5.0 or
223224
// seconds=5.999, etc.
224225
DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);
225-
DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
226+
DecodedOutput getFrameAtIndex(
227+
int streamIndex,
228+
int64_t frameIndex,
229+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
226230
struct BatchDecodedOutput {
227231
torch::Tensor frames;
228232
torch::Tensor ptsSeconds;
@@ -363,10 +367,13 @@ class VideoDecoder {
363367
int streamIndex,
364368
const AVFrame* frame);
365369
void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput);
366-
DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput);
370+
DecodedOutput convertAVFrameToDecodedOutput(
371+
RawDecodedOutput& rawOutput,
372+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
367373
void convertAVFrameToDecodedOutputOnCPU(
368374
RawDecodedOutput& rawOutput,
369-
DecodedOutput& output);
375+
DecodedOutput& output,
376+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
370377

371378
DecoderOptions options_;
372379
ContainerMetadata containerMetadata_;

0 commit comments

Comments
 (0)