Skip to content

Commit 1751c6b

Browse files
committed
Some cleanup
1 parent c6fcf16 commit 1751c6b

File tree

2 files changed

+31
-52
lines changed

2 files changed

+31
-52
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 29 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -553,9 +553,7 @@ void VideoDecoder::addVideoStream(
553553
void VideoDecoder::addAudioStream(int streamIndex) {
554554
addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
555555

556-
// TODO address this, this is currently super limitting. The main thing we'll
557-
// need to handle is the pre-allocation of the output tensor in batch APIs. We
558-
// probably won't be able to pre-allocate anything.
556+
// See correspodning TODO in makeFrameBatchOutput
559557
auto& streamInfo = streamInfos_[activeStreamIndex_];
560558
TORCH_CHECK(
561559
streamInfo.codecContext->frame_size > 0,
@@ -627,21 +625,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
627625

628626
const auto& streamMetadata =
629627
containerMetadata_.allStreamMetadata[activeStreamIndex_];
630-
const auto& streamInfo = streamInfos_[activeStreamIndex_];
631628

632-
// TODO_CODE_QUALITY Better allocation logic.
633-
FrameBatchOutput frameBatchOutput;
634-
if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) {
635-
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
636-
frameBatchOutput = FrameBatchOutput(
637-
frameIndices.size(), videoStreamOptions, streamMetadata);
638-
} else {
639-
// TODO Handle case if frame_size is not known.
640-
int64_t numSamples = streamInfo.codecContext->frame_size;
641-
int64_t numChannels = getNumChannels(streamInfo.codecContext);
642-
frameBatchOutput =
643-
FrameBatchOutput(frameIndices.size(), numChannels, numSamples);
644-
}
629+
FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(frameIndices.size());
645630

646631
auto previousIndexInVideo = -1;
647632
for (size_t f = 0; f < frameIndices.size(); ++f) {
@@ -678,7 +663,6 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
678663

679664
const auto& streamMetadata =
680665
containerMetadata_.allStreamMetadata[activeStreamIndex_];
681-
const auto& streamInfo = streamInfos_[activeStreamIndex_];
682666
int64_t numFrames = getNumFrames(streamMetadata);
683667
TORCH_CHECK(
684668
start >= 0, "Range start, " + std::to_string(start) + " is less than 0.");
@@ -691,19 +675,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
691675

692676
int64_t numOutputFrames = std::ceil((stop - start) / double(step));
693677

694-
// TODO_CODE_QUALITY Better allocation logic.
695-
FrameBatchOutput frameBatchOutput;
696-
if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) {
697-
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
698-
frameBatchOutput =
699-
FrameBatchOutput(numOutputFrames, videoStreamOptions, streamMetadata);
700-
} else {
701-
// TODO Handle case if frame_size is not known.
702-
int64_t numSamples = streamInfo.codecContext->frame_size;
703-
int64_t numChannels = getNumChannels(streamInfo.codecContext);
704-
frameBatchOutput =
705-
FrameBatchOutput(numOutputFrames, numChannels, numSamples);
706-
}
678+
FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(numOutputFrames);
707679

708680
for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
709681
FrameOutput frameOutput =
@@ -789,17 +761,12 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
789761
double stopSeconds) {
790762
validateActiveStream();
791763

792-
const auto& streamMetadata =
793-
containerMetadata_.allStreamMetadata[activeStreamIndex_];
794764
TORCH_CHECK(
795765
startSeconds <= stopSeconds,
796766
"Start seconds (" + std::to_string(startSeconds) +
797767
") must be less than or equal to stop seconds (" +
798768
std::to_string(stopSeconds) + ".");
799769

800-
const auto& streamInfo = streamInfos_[activeStreamIndex_];
801-
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
802-
803770
// Special case needed to implement a half-open range. At first glance, this
804771
// may seem unnecessary, as our search for stopFrame can return the end, and
805772
// we don't include stopFramIndex in our output. However, consider the
@@ -818,12 +785,13 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
818785
// values of the intervals will map to the same frame indices below. Hence, we
819786
// need this special case below.
820787
if (startSeconds == stopSeconds) {
821-
// TODO handle audio here
822-
FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata);
788+
FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(0);
823789
frameBatchOutput.data = maybePermuteOutputTensor(frameBatchOutput.data);
824790
return frameBatchOutput;
825791
}
826792

793+
const auto& streamMetadata =
794+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
827795
double minSeconds = getMinSeconds(streamMetadata);
828796
double maxSeconds = getMaxSeconds(streamMetadata);
829797
TORCH_CHECK(
@@ -854,18 +822,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
854822
int64_t stopFrameIndex = secondsToIndexUpperBound(stopSeconds);
855823
int64_t numFrames = stopFrameIndex - startFrameIndex;
856824

857-
// TODO_CODE_QUALITY Better allocation logic.
858-
FrameBatchOutput frameBatchOutput;
859-
if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) {
860-
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
861-
frameBatchOutput =
862-
FrameBatchOutput(numFrames, videoStreamOptions, streamMetadata);
863-
} else {
864-
// TODO Handle case if frame_size is not known.
865-
int64_t numSamples = streamInfo.codecContext->frame_size;
866-
int64_t numChannels = getNumChannels(streamInfo.codecContext);
867-
frameBatchOutput = FrameBatchOutput(numFrames, numChannels, numSamples);
868-
}
825+
FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(numFrames);
826+
869827
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
870828
FrameOutput frameOutput =
871829
getFrameAtIndexInternal(i, frameBatchOutput.data[f]);
@@ -1391,7 +1349,7 @@ VideoDecoder::FrameBatchOutput::FrameBatchOutput(
13911349
height, width, videoStreamOptions.device, numFrames);
13921350
}
13931351

1394-
VideoDecoder::FrameBatchOutput ::FrameBatchOutput(
1352+
VideoDecoder::FrameBatchOutput::FrameBatchOutput(
13951353
int64_t numFrames,
13961354
int64_t numChannels,
13971355
int64_t numSamples)
@@ -1405,6 +1363,26 @@ VideoDecoder::FrameBatchOutput ::FrameBatchOutput(
14051363
data = torch::empty({numFrames, numChannels, numSamples}, tensorOptions);
14061364
}
14071365

1366+
VideoDecoder::FrameBatchOutput VideoDecoder::makeFrameBatchOutput(
1367+
int64_t numFrames) {
1368+
const auto& streamInfo = streamInfos_[activeStreamIndex_];
1369+
if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) {
1370+
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
1371+
const auto& streamMetadata =
1372+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
1373+
return FrameBatchOutput(numFrames, videoStreamOptions, streamMetadata);
1374+
} else {
1375+
// We asserted that frame_size is non-zero when we added the stream, but it
1376+
// may not always be the case.
1377+
// When it's 0, we can't pre-allocate the output tensor as we don't know the
1378+
// number of samples per channel, and it may be non-constant.
1379+
// TODO: handle this.
1380+
int64_t numSamples = streamInfo.codecContext->frame_size;
1381+
int64_t numChannels = getNumChannels(streamInfo.codecContext);
1382+
return FrameBatchOutput(numFrames, numChannels, numSamples);
1383+
}
1384+
}
1385+
14081386
torch::Tensor allocateEmptyHWCTensor(
14091387
int height,
14101388
int width,

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,6 @@ class VideoDecoder {
162162
torch::Tensor ptsSeconds; // 1D of shape (N,)
163163
torch::Tensor durationSeconds; // 1D of shape (N,)
164164

165-
FrameBatchOutput(){};
166165
explicit FrameBatchOutput(
167166
int64_t numFrames,
168167
const VideoStreamOptions& videoStreamOptions,
@@ -399,6 +398,8 @@ class VideoDecoder {
399398
const AVFrame* avFrame,
400399
torch::Tensor& outputTensor);
401400

401+
FrameBatchOutput makeFrameBatchOutput(int64_t numFrames);
402+
402403
// --------------------------------------------------------------------------
403404
// COLOR CONVERSION LIBRARIES HANDLERS CREATION
404405
// --------------------------------------------------------------------------

0 commit comments

Comments
 (0)