Skip to content

Commit 8c11e43

Browse files
committed
Unify allocation within allocateEmptyHWCTensorForStream
1 parent 7cf3a3e commit 8c11e43

File tree

2 files changed

+41
-46
lines changed

2 files changed

+41
-46
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 30 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -187,27 +187,30 @@ VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
187187
}
188188
}
189189

190-
torch::Tensor makeEmptyHWCTensor(
191-
int height,
192-
int width,
193-
std::optional<int> numFrames = std::nullopt) {
190+
torch::Tensor VideoDecoder::allocateEmptyHWCTensorForStream(
191+
int streamIndex,
192+
std::optional<int> numFrames) {
193+
auto metadata = containerMetadata_.streams[streamIndex];
194+
auto options = streams_[streamIndex].options;
195+
auto height = options.height.value_or(*metadata.height);
196+
auto width = options.width.value_or(*metadata.width);
197+
194198
if (numFrames.has_value()) {
195199
return torch::empty({numFrames.value(), height, width, 3}, {torch::kUInt8});
196200
} else {
197201
return torch::empty({height, width, 3}, {torch::kUInt8});
198202
}
199203
}
200204

201-
VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
202-
int64_t numFrames,
203-
const VideoStreamDecoderOptions& options,
204-
const StreamMetadata& metadata)
205-
: frames(makeEmptyHWCTensor(
206-
options.height.value_or(*metadata.height),
207-
options.width.value_or(*metadata.width),
208-
numFrames)),
209-
ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
210-
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}
205+
VideoDecoder::BatchDecodedOutput VideoDecoder::allocateBatchDecodedOutput(
206+
int streamIndex,
207+
int64_t numFrames) {
208+
BatchDecodedOutput output;
209+
output.frames = allocateEmptyHWCTensorForStream(streamIndex, numFrames);
210+
output.ptsSeconds = torch::empty({numFrames}, {torch::kFloat64});
211+
output.durationSeconds = torch::empty({numFrames}, {torch::kFloat64});
212+
return output;
213+
}
211214

212215
VideoDecoder::VideoDecoder() {}
213216

@@ -981,14 +984,10 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
981984
});
982985
// Convert the frame to tensor.
983986
auto streamIndex = rawOutput.streamIndex;
984-
auto metadata = containerMetadata_.streams[streamIndex];
985-
auto options = streams_[streamIndex].options;
986-
auto height = options.height.value_or(*metadata.height);
987-
auto width = options.width.value_or(*metadata.width);
988-
auto preAllocatedOutputTensor = makeEmptyHWCTensor(height, width);
987+
auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream(streamIndex);
989988
auto output =
990989
convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
991-
output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
990+
output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
992991
return output;
993992
}
994993

@@ -1025,11 +1024,7 @@ void VideoDecoder::validateFrameIndex(
10251024
VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
10261025
int streamIndex,
10271026
int64_t frameIndex) {
1028-
auto metadata = containerMetadata_.streams[streamIndex];
1029-
auto options = streams_[streamIndex].options;
1030-
auto height = options.height.value_or(*metadata.height);
1031-
auto width = options.width.value_or(*metadata.width);
1032-
auto preAllocatedOutputTensor = makeEmptyHWCTensor(height, width);
1027+
auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream(streamIndex);
10331028
auto output = getFrameAtIndexInternal(
10341029
streamIndex, frameIndex, preAllocatedOutputTensor);
10351030
output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
@@ -1079,7 +1074,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10791074
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
10801075
const auto& stream = streams_[streamIndex];
10811076
const auto& options = stream.options;
1082-
BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
1077+
BatchDecodedOutput output =
1078+
allocateBatchDecodedOutput(streamIndex, frameIndices.size());
10831079

10841080
auto previousIndexInVideo = -1;
10851081
for (auto f = 0; f < frameIndices.size(); ++f) {
@@ -1171,8 +1167,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
11711167
step > 0, "Step must be greater than 0; is " + std::to_string(step));
11721168

11731169
int64_t numOutputFrames = std::ceil((stop - start) / double(step));
1174-
const auto& options = stream.options;
1175-
BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
1170+
BatchDecodedOutput output =
1171+
allocateBatchDecodedOutput(streamIndex, numOutputFrames);
11761172

11771173
for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
11781174
DecodedOutput singleOut =
@@ -1211,9 +1207,6 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12111207
"; must be less than or equal to " + std::to_string(maxSeconds) +
12121208
").");
12131209

1214-
const auto& stream = streams_[streamIndex];
1215-
const auto& options = stream.options;
1216-
12171210
// Special case needed to implement a half-open range. At first glance, this
12181211
// may seem unnecessary, as our search for stopFrame can return the end, and
12191212
// we don't include stopFrameIndex in our output. However, consider the
@@ -1232,7 +1225,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12321225
// values of the intervals will map to the same frame indices below. Hence, we
12331226
// need this special case below.
12341227
if (startSeconds == stopSeconds) {
1235-
BatchDecodedOutput output(0, options, streamMetadata);
1228+
BatchDecodedOutput output = allocateBatchDecodedOutput(streamIndex, 0);
12361229
output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
12371230
return output;
12381231
}
@@ -1248,6 +1241,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12481241
// 2. In order to establish if the start of an interval maps to a particular
12491242
// frame, we need to figure out if it is ordered after the frame's pts, but
12501243
// before the next frame's pts.
1244+
const auto& stream = streams_[streamIndex];
12511245
auto startFrame = std::lower_bound(
12521246
stream.allFrames.begin(),
12531247
stream.allFrames.end(),
@@ -1267,7 +1261,8 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12671261
int64_t startFrameIndex = startFrame - stream.allFrames.begin();
12681262
int64_t stopFrameIndex = stopFrame - stream.allFrames.begin();
12691263
int64_t numFrames = stopFrameIndex - startFrameIndex;
1270-
BatchDecodedOutput output(numFrames, options, streamMetadata);
1264+
BatchDecodedOutput output =
1265+
allocateBatchDecodedOutput(streamIndex, numFrames);
12711266
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
12721267
DecodedOutput singleOut =
12731268
getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
@@ -1291,14 +1286,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
12911286
VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
12921287
auto rawOutput = getNextRawDecodedOutputNoDemux();
12931288
auto streamIndex = rawOutput.streamIndex;
1294-
auto metadata = containerMetadata_.streams[streamIndex];
1295-
auto options = streams_[streamIndex].options;
1296-
auto height = options.height.value_or(*metadata.height);
1297-
auto width = options.width.value_or(*metadata.width);
1298-
auto preAllocatedOutputTensor = makeEmptyHWCTensor(height, width);
1289+
auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream(streamIndex);
12991290
auto output =
13001291
convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
1301-
output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
1292+
output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
13021293
return output;
13031294
}
13041295

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,6 @@ class VideoDecoder {
157157
int streamIndex,
158158
const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions());
159159

160-
torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
161-
162160
// ---- SINGLE FRAME SEEK AND DECODING API ----
163161
// Places the cursor at the first frame on or after the position in seconds.
164162
// Calling getNextFrameOutputNoDemuxInternal() will return the first frame at
@@ -238,12 +236,10 @@ class VideoDecoder {
238236
torch::Tensor frames;
239237
torch::Tensor ptsSeconds;
240238
torch::Tensor durationSeconds;
241-
242-
explicit BatchDecodedOutput(
243-
int64_t numFrames,
244-
const VideoStreamDecoderOptions& options,
245-
const StreamMetadata& metadata);
246239
};
240+
BatchDecodedOutput allocateBatchDecodedOutput(
241+
int streamIndex,
242+
int64_t numFrames);
247243
// Returns frames at the given indices for a given stream as a single stacked
248244
// Tensor.
249245
BatchDecodedOutput getFramesAtIndices(
@@ -302,6 +298,14 @@ class VideoDecoder {
302298

303299
double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex);
304300

301+
// --------------------------------------------------------------------------
302+
// Tensor (frames) manipulation APIs
303+
// --------------------------------------------------------------------------
304+
torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
305+
torch::Tensor allocateEmptyHWCTensorForStream(
306+
int streamIndex,
307+
std::optional<int> numFrames = std::nullopt);
308+
305309
private:
306310
struct FrameInfo {
307311
int64_t pts = 0;

0 commit comments

Comments
 (0)