Skip to content

Commit f83ada9

Browse files
committed
Pre-allocate tensors when possible to avoid copies
1 parent 025bf27 commit f83ada9

File tree

3 files changed

+57
-22
lines changed

3 files changed

+57
-22
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 39 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -846,7 +846,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
846846
}
847847

848848
VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
849-
VideoDecoder::RawDecodedOutput& rawOutput) {
849+
VideoDecoder::RawDecodedOutput& rawOutput,
850+
torch::Tensor& preAllocatedOutputTensor) {
850851
// Convert the frame to tensor.
851852
DecodedOutput output;
852853
int streamIndex = rawOutput.streamIndex;
@@ -861,8 +862,10 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
861862
output.durationSeconds = ptsToSeconds(
862863
getDuration(frame), formatContext_->streams[streamIndex]->time_base);
863864
if (streamInfo.options.device.type() == torch::kCPU) {
864-
convertAVFrameToDecodedOutputOnCPU(rawOutput, output);
865+
convertAVFrameToDecodedOutputOnCPU(
866+
rawOutput, output, preAllocatedOutputTensor);
865867
} else if (streamInfo.options.device.type() == torch::kCUDA) {
868+
// TODO: handle pre-allocated output tensor
866869
convertAVFrameToDecodedOutputOnCuda(
867870
streamInfo.options.device,
868871
streamInfo.options,
@@ -878,16 +881,21 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
878881

879882
void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
880883
VideoDecoder::RawDecodedOutput& rawOutput,
881-
DecodedOutput& output) {
884+
DecodedOutput& output,
885+
torch::Tensor& preAllocatedOutputTensor) {
882886
int streamIndex = rawOutput.streamIndex;
883887
AVFrame* frame = rawOutput.frame.get();
884888
auto& streamInfo = streams_[streamIndex];
885889
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
886890
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
887-
int width = streamInfo.options.width.value_or(frame->width);
888-
int height = streamInfo.options.height.value_or(frame->height);
889-
torch::Tensor tensor = torch::empty(
890-
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
891+
torch::Tensor tensor;
892+
if (preAllocatedOutputTensor.numel() != 0) {
893+
// TODO: check shape of preAllocatedOutputTensor?
894+
tensor = preAllocatedOutputTensor;
895+
} else {
896+
tensor = allocateOutputTensorFromRawOutput(rawOutput);
897+
}
898+
891899
rawOutput.data = tensor.data_ptr<uint8_t>();
892900
convertFrameToBufferUsingSwsScale(rawOutput);
893901

@@ -912,6 +920,16 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
912920
}
913921
}
914922

923+
// Creates the output tensor for the frame held by rawOutput.
//
// The tensor is built with torch::empty, has dtype uint8 and shape
// {height, width, 3}. Width and height come from the stream options when
// they are set, and otherwise from the dimensions of the decoded AVFrame.
// No device is specified in the TensorOptions, and this function does not
// write any pixel data into the tensor.
// NOTE(review): the trailing 3 is presumably the colour channel count of the
// converted frame — confirm against convertFrameToBufferUsingSwsScale.
torch::Tensor VideoDecoder::allocateOutputTensorFromRawOutput(
924+
RawDecodedOutput& rawOutput) {
925+
AVFrame* frame = rawOutput.frame.get();
926+
StreamInfo& streamInfo = streams_[rawOutput.streamIndex];
927+
int width = streamInfo.options.width.value_or(frame->width);
928+
int height = streamInfo.options.height.value_or(frame->height);
929+
return torch::empty(
930+
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
931+
}
932+
915933
VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux(
916934
double seconds) {
917935
for (auto& [streamIndex, stream] : streams_) {
@@ -945,7 +963,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux(
945963
return seconds >= frameStartTime && seconds < frameEndTime;
946964
});
947965
// Convert the frame to tensor.
948-
return convertAVFrameToDecodedOutput(rawOutput);
966+
auto preAllocatedOutputTensor = allocateOutputTensorFromRawOutput(rawOutput);
967+
return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
949968
}
950969

951970
void VideoDecoder::validateUserProvidedStreamIndex(uint64_t streamIndex) {
@@ -980,7 +999,8 @@ void VideoDecoder::validateFrameIndex(
980999

9811000
VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
9821001
int streamIndex,
983-
int64_t frameIndex) {
1002+
int64_t frameIndex,
1003+
torch::Tensor& preAllocatedOutputTensor) {
9841004
validateUserProvidedStreamIndex(streamIndex);
9851005
validateScannedAllStreams("getFrameAtIndex");
9861006

@@ -989,7 +1009,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
9891009

9901010
int64_t pts = stream.allFrames[frameIndex].pts;
9911011
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
992-
return getNextDecodedOutputNoDemux();
1012+
return getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
9931013
}
9941014

9951015
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
@@ -1061,8 +1081,9 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10611081
BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
10621082

10631083
for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
1064-
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
1065-
output.frames[f] = singleOut.frame;
1084+
auto preAllocatedOutputTensor = output.frames[f];
1085+
DecodedOutput singleOut =
1086+
getFrameAtIndex(streamIndex, i, preAllocatedOutputTensor);
10661087
output.ptsSeconds[f] = singleOut.ptsSeconds;
10671088
output.durationSeconds[f] = singleOut.durationSeconds;
10681089
}
@@ -1154,8 +1175,9 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11541175
int64_t numFrames = stopFrameIndex - startFrameIndex;
11551176
BatchDecodedOutput output(numFrames, options, streamMetadata);
11561177
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
1157-
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
1158-
output.frames[f] = singleOut.frame;
1178+
auto preAllocatedOutputTensor = output.frames[f];
1179+
DecodedOutput singleOut =
1180+
getFrameAtIndex(streamIndex, i, preAllocatedOutputTensor);
11591181
output.ptsSeconds[f] = singleOut.ptsSeconds;
11601182
output.durationSeconds[f] = singleOut.durationSeconds;
11611183
}
@@ -1173,9 +1195,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
11731195
return rawOutput;
11741196
}
11751197

1176-
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
1198+
// Fetches the next raw decoded output via getNextRawDecodedOutputNoDemux()
// and converts it to a DecodedOutput.
//
// preAllocatedOutputTensor is forwarded unchanged to
// convertAVFrameToDecodedOutput. In the CPU swscale conversion shown in this
// commit, a tensor with numel() == 0 makes the converter allocate its own
// output, and any other tensor is used directly as the destination buffer.
// NOTE(review): the shape of a non-empty tensor is not validated anywhere in
// the visible code — confirm callers always pass a matching {H, W, 3} uint8
// tensor.
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux(
1199+
torch::Tensor& preAllocatedOutputTensor) {
11771200
auto rawOutput = getNextRawDecodedOutputNoDemux();
1178-
return convertAVFrameToDecodedOutput(rawOutput);
1201+
return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
11791202
}
11801203

11811204
void VideoDecoder::setCursorPtsInSeconds(double seconds) {

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -214,15 +214,19 @@ class VideoDecoder {
214214
};
215215
// Decodes the frame where the current cursor position is. It also advances
216216
// the cursor to the next frame.
217-
DecodedOutput getNextDecodedOutputNoDemux();
217+
DecodedOutput getNextDecodedOutputNoDemux(
218+
torch::Tensor& preAllocatedOutputTensor);
218219
// Decodes the first frame in any added stream that is visible at a given
219220
// timestamp. Frames in the video have a presentation timestamp and a
220221
// duration. For example, if a frame has presentation timestamp of 5.0s and a
221222
// duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
222223
// i.e. it will be returned when this function is called with seconds=5.0 or
223224
// seconds=5.999, etc.
224225
DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);
225-
DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
226+
DecodedOutput getFrameAtIndex(
227+
int streamIndex,
228+
int64_t frameIndex,
229+
torch::Tensor& preAllocatedOutputTensor);
226230
struct BatchDecodedOutput {
227231
torch::Tensor frames;
228232
torch::Tensor ptsSeconds;
@@ -363,10 +367,14 @@ class VideoDecoder {
363367
int streamIndex,
364368
const AVFrame* frame);
365369
void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput);
366-
DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput);
370+
DecodedOutput convertAVFrameToDecodedOutput(
371+
RawDecodedOutput& rawOutput,
372+
torch::Tensor& preAllocatedOutputTensor);
367373
void convertAVFrameToDecodedOutputOnCPU(
368374
RawDecodedOutput& rawOutput,
369-
DecodedOutput& output);
375+
DecodedOutput& output,
376+
torch::Tensor& preAllocatedOutputTensor);
377+
torch::Tensor allocateOutputTensorFromRawOutput(RawDecodedOutput& rawOutput);
370378

371379
DecoderOptions options_;
372380
ContainerMetadata containerMetadata_;

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -190,8 +190,10 @@ void seek_to_pts(at::Tensor& decoder, double seconds) {
190190
OpsDecodedOutput get_next_frame(at::Tensor& decoder) {
191191
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
192192
VideoDecoder::DecodedOutput result;
193+
auto preAllocatedOutputTensor = torch::empty({0});
193194
try {
194-
result = videoDecoder->getNextDecodedOutputNoDemux();
195+
result =
196+
videoDecoder->getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
195197
} catch (const VideoDecoder::EndOfFileException& e) {
196198
C10_THROW_ERROR(IndexError, e.what());
197199
}
@@ -214,7 +216,9 @@ OpsDecodedOutput get_frame_at_index(
214216
int64_t stream_index,
215217
int64_t frame_index) {
216218
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
217-
auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index);
219+
auto preAllocatedOutputTensor = torch::empty({0});
220+
auto result = videoDecoder->getFrameAtIndex(
221+
stream_index, frame_index, preAllocatedOutputTensor);
218222
return makeOpsDecodedOutput(result);
219223
}
220224

0 commit comments

Comments
 (0)