Skip to content

Commit ae15304

Browse files
committed
Add basic range support
1 parent 79954f9 commit ae15304

File tree

9 files changed

+396
-55
lines changed

9 files changed

+396
-55
lines changed

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ int64_t getDuration(const AVFrame* frame) {
6060
#endif
6161
}
6262

63+
int64_t getNumChannels(const AVFrame* avFrame) {
64+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
65+
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
66+
int numChannels = avFrame->ch_layout.nb_channels;
67+
#else
68+
int numChannels = av_get_channel_layout_nb_channels(avFrame->channel_layout);
69+
#endif
70+
71+
return static_cast<int64_t>(numChannels);
72+
}
73+
6374
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext) {
6475
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
6576
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
139139
int64_t getDuration(const UniqueAVFrame& frame);
140140
int64_t getDuration(const AVFrame* frame);
141141

142+
int64_t getNumChannels(const AVFrame* avFrame);
142143
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext);
143144

144145
// Returns true if sws_scale can handle unaligned data.

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 153 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,20 @@ void VideoDecoder::initializeDecoder() {
169169
}
170170
containerMetadata_.numVideoStreams++;
171171
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
172+
// TODO-AUDIO Remove this, we shouldn't need it. We should probably write
173+
// a pts-based "getFramesPlayedInRange" from scratch without going back to
174+
// indices.
175+
int numSamplesPerFrame = avStream->codecpar->frame_size;
176+
int sampleRate = avStream->codecpar->sample_rate;
177+
if (numSamplesPerFrame > 0 && sampleRate > 0) {
178+
// This should allow the approximate mode to do its magic.
179+
// fps is numFrames / duration where
180+
// - duration = numSamplesTotal / sampleRate and
181+
// - numSamplesTotal = numSamplesPerFrame * numFrames
182+
// so fps = numFrames * sampleRate / (numSamplesPerFrame * numFrames)
183+
streamMetadata.averageFps =
184+
static_cast<double>(sampleRate) / numSamplesPerFrame;
185+
}
172186
containerMetadata_.numAudioStreams++;
173187
}
174188

@@ -549,8 +563,20 @@ void VideoDecoder::addAudioStream(int streamIndex) {
549563
addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
550564

551565
auto& streamInfo = streamInfos_[activeStreamIndex_];
566+
567+
// TODO-AUDIO
568+
TORCH_CHECK(
569+
streamInfo.codecContext->frame_size > 0,
570+
"No support for audio variable framerate yet.");
571+
552572
auto& streamMetadata =
553573
containerMetadata_.allStreamMetadata[activeStreamIndex_];
574+
575+
// TODO-AUDIO
576+
TORCH_CHECK(
577+
streamMetadata.averageFps.has_value(),
578+
"frame_size or sample_rate aren't known. Cannot decode.");
579+
554580
streamMetadata.sampleRate =
555581
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
556582
streamMetadata.numChannels = getNumChannels(streamInfo.codecContext);
@@ -562,7 +588,7 @@ void VideoDecoder::addAudioStream(int streamIndex) {
562588

563589
VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
564590
auto output = getNextFrameInternal();
565-
output.data = maybePermuteHWC2CHW(output.data);
591+
output.data = maybePermuteOutputTensor(output.data);
566592
return output;
567593
}
568594

@@ -576,6 +602,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
576602
}
577603

578604
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
605+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
579606
auto frameOutput = getFrameAtIndexInternal(frameIndex);
580607
frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
581608
return frameOutput;
@@ -584,7 +611,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
584611
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
585612
int64_t frameIndex,
586613
std::optional<torch::Tensor> preAllocatedOutputTensor) {
587-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
614+
validateActiveStream();
588615

589616
const auto& streamInfo = streamInfos_[activeStreamIndex_];
590617
const auto& streamMetadata =
@@ -688,6 +715,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
688715
}
689716

690717
VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
718+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
691719
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
692720
double frameStartTime =
693721
ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
@@ -759,19 +787,29 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt(
759787
VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
760788
double startSeconds,
761789
double stopSeconds) {
762-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
790+
validateActiveStream();
791+
// Because we currently never seek with audio streams, we prevent users from
792+
// calling this method twice. We could allow multiple calls in the future.
793+
// Assuming 2 consecutive calls:
794+
// ```
795+
// getFramesPlayedInRange(startSeconds1, stopSeconds1);
796+
// getFramesPlayedInRange(startSeconds2, stopSeconds2);
797+
// ```
798+
// We would need to seek back to 0 iff startSeconds2 <= stopSeconds1. This
799+
// logic is not implemented for now, so we just error.
800+
801+
TORCH_CHECK(
802+
streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO ||
803+
!alreadyCalledGetFramesPlayedInRange_,
804+
"Can only decode once with audio stream. Re-create a decoder object if needed.")
805+
alreadyCalledGetFramesPlayedInRange_ = true;
763806

764-
const auto& streamMetadata =
765-
containerMetadata_.allStreamMetadata[activeStreamIndex_];
766807
TORCH_CHECK(
767808
startSeconds <= stopSeconds,
768809
"Start seconds (" + std::to_string(startSeconds) +
769810
") must be less than or equal to stop seconds (" +
770811
std::to_string(stopSeconds) + ".");
771812

772-
const auto& streamInfo = streamInfos_[activeStreamIndex_];
773-
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
774-
775813
// Special case needed to implement a half-open range. At first glance, this
776814
// may seem unnecessary, as our search for stopFrame can return the end, and
777815
// we don't include stopFrameIndex in our output. However, consider the
@@ -790,11 +828,14 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
790828
// values of the intervals will map to the same frame indices below. Hence, we
791829
// need this special case below.
792830
if (startSeconds == stopSeconds) {
793-
FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata);
794-
frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
831+
FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(0);
832+
frameBatchOutput.data = maybePermuteOutputTensor(frameBatchOutput.data);
795833
return frameBatchOutput;
796834
}
797835

836+
const auto& streamMetadata =
837+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
838+
798839
double minSeconds = getMinSeconds(streamMetadata);
799840
double maxSeconds = getMaxSeconds(streamMetadata);
800841
TORCH_CHECK(
@@ -825,15 +866,14 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
825866
int64_t stopFrameIndex = secondsToIndexUpperBound(stopSeconds);
826867
int64_t numFrames = stopFrameIndex - startFrameIndex;
827868

828-
FrameBatchOutput frameBatchOutput(
829-
numFrames, videoStreamOptions, streamMetadata);
869+
FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(numFrames);
830870
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
831871
FrameOutput frameOutput =
832872
getFrameAtIndexInternal(i, frameBatchOutput.data[f]);
833873
frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds;
834874
frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds;
835875
}
836-
frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
876+
frameBatchOutput.data = maybePermuteOutputTensor(frameBatchOutput.data);
837877

838878
return frameBatchOutput;
839879
}
@@ -872,8 +912,12 @@ I P P P I P P P I P P I P P I P
872912
(2) is more efficient than (1) if there is an I frame between x and y.
873913
*/
874914
bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
875-
int64_t lastDecodedAvFramePts =
876-
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
915+
const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
916+
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
917+
return true;
918+
}
919+
920+
int64_t lastDecodedAvFramePts = streamInfo.lastDecodedAvFramePts;
877921
if (targetPts < lastDecodedAvFramePts) {
878922
// We can never skip a seek if we are seeking backwards.
879923
return false;
@@ -898,7 +942,7 @@ bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
898942
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
899943
// the comment of canWeAvoidSeeking() for details.
900944
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
901-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
945+
validateActiveStream();
902946
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
903947

904948
int64_t desiredPts =
@@ -945,7 +989,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
945989

946990
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
947991
std::function<bool(AVFrame*)> filterFunction) {
948-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
992+
validateActiveStream();
949993

950994
resetDecodeStats();
951995

@@ -1075,13 +1119,14 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
10751119
AVFrame* avFrame = avFrameStream.avFrame.get();
10761120
frameOutput.streamIndex = streamIndex;
10771121
auto& streamInfo = streamInfos_[streamIndex];
1078-
TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
10791122
frameOutput.ptsSeconds = ptsToSeconds(
10801123
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
10811124
frameOutput.durationSeconds = ptsToSeconds(
10821125
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
1083-
// TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
1084-
if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
1126+
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1127+
convertAudioAVFrameToFrameOutputOnCPU(
1128+
avFrameStream, frameOutput, preAllocatedOutputTensor);
1129+
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
10851130
convertAVFrameToFrameOutputOnCPU(
10861131
avFrameStream, frameOutput, preAllocatedOutputTensor);
10871132
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
@@ -1257,6 +1302,48 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
12571302
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
12581303
}
12591304

1305+
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
1306+
VideoDecoder::AVFrameStream& avFrameStream,
1307+
FrameOutput& frameOutput,
1308+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1309+
const AVFrame* avFrame = avFrameStream.avFrame.get();
1310+
1311+
auto numSamples = avFrame->nb_samples; // per channel
1312+
auto numChannels = getNumChannels(avFrame);
1313+
1314+
// TODO-AUDIO: dtype should be format-dependent
1315+
// TODO-AUDIO rename data to something else
1316+
torch::Tensor data;
1317+
if (preAllocatedOutputTensor.has_value()) {
1318+
data = preAllocatedOutputTensor.value();
1319+
} else {
1320+
data = torch::empty({numChannels, numSamples}, torch::kFloat32);
1321+
}
1322+
1323+
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
1324+
// TODO Implement all formats
1325+
switch (format) {
1326+
case AV_SAMPLE_FMT_FLTP: {
1327+
uint8_t* outputChannelData = static_cast<uint8_t*>(data.data_ptr());
1328+
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
1329+
for (auto channel = 0; channel < numChannels;
1330+
++channel, outputChannelData += numBytesPerChannel) {
1331+
memcpy(
1332+
outputChannelData,
1333+
avFrame->extended_data[channel],
1334+
numBytesPerChannel);
1335+
}
1336+
break;
1337+
}
1338+
default:
1339+
TORCH_CHECK(
1340+
false,
1341+
"Unsupported audio format (yet!): ",
1342+
av_get_sample_fmt_name(format));
1343+
}
1344+
frameOutput.data = data;
1345+
}
1346+
12601347
// --------------------------------------------------------------------------
12611348
// OUTPUT ALLOCATION AND SHAPE CONVERSION
12621349
// --------------------------------------------------------------------------
@@ -1275,6 +1362,41 @@ VideoDecoder::FrameBatchOutput::FrameBatchOutput(
12751362
height, width, videoStreamOptions.device, numFrames);
12761363
}
12771364

1365+
VideoDecoder::FrameBatchOutput::FrameBatchOutput(
1366+
int64_t numFrames,
1367+
int64_t numChannels,
1368+
int64_t numSamples)
1369+
: ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
1370+
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
1371+
// TODO handle dtypes other than float
1372+
auto tensorOptions = torch::TensorOptions()
1373+
.dtype(torch::kFloat32)
1374+
.layout(torch::kStrided)
1375+
.device(torch::kCPU);
1376+
data = torch::empty({numFrames, numChannels, numSamples}, tensorOptions);
1377+
}
1378+
1379+
VideoDecoder::FrameBatchOutput VideoDecoder::makeFrameBatchOutput(
1380+
int64_t numFrames) {
1381+
const auto& streamInfo = streamInfos_[activeStreamIndex_];
1382+
if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) {
1383+
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
1384+
const auto& streamMetadata =
1385+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
1386+
return FrameBatchOutput(numFrames, videoStreamOptions, streamMetadata);
1387+
} else {
1388+
// TODO-AUDIO
1389+
// We asserted that frame_size is non-zero when we added the stream, but it
1390+
// may not always be the case.
1391+
// When it's 0, we can't pre-allocate the output tensor as we don't know the
1392+
// number of samples per channel, and it may be non-constant. We'll have to
1393+
// find a way to make the batch-APIs work without pre-allocation.
1394+
int64_t numSamples = streamInfo.codecContext->frame_size;
1395+
int64_t numChannels = getNumChannels(streamInfo.codecContext);
1396+
return FrameBatchOutput(numFrames, numChannels, numSamples);
1397+
}
1398+
}
1399+
12781400
torch::Tensor allocateEmptyHWCTensor(
12791401
int height,
12801402
int width,
@@ -1296,6 +1418,17 @@ torch::Tensor allocateEmptyHWCTensor(
12961418
}
12971419
}
12981420

1421+
torch::Tensor VideoDecoder::maybePermuteOutputTensor(
1422+
torch::Tensor& outputTensor) {
1423+
if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
1424+
return maybePermuteHWC2CHW(outputTensor);
1425+
} else {
1426+
// No need to do anything for audio. We always return (numChannels,
1427+
// numSamples) or (numFrames, numChannels, numSamples)
1428+
return outputTensor;
1429+
}
1430+
}
1431+
12991432
// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
13001433
// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
13011434
// or 4D.

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ class VideoDecoder {
168168
int64_t numFrames,
169169
const VideoStreamOptions& videoStreamOptions,
170170
const StreamMetadata& streamMetadata);
171+
explicit FrameBatchOutput(
172+
int64_t numFrames,
173+
int64_t numChannels,
174+
int64_t numSamples);
171175
};
172176

173177
// Places the cursor at the first frame on or after the position in seconds.
@@ -372,6 +376,7 @@ class VideoDecoder {
372376
FrameOutput getNextFrameInternal(
373377
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
374378

379+
torch::Tensor maybePermuteOutputTensor(torch::Tensor& outputTensor);
375380
torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor);
376381

377382
FrameOutput convertAVFrameToFrameOutput(
@@ -383,12 +388,18 @@ class VideoDecoder {
383388
FrameOutput& frameOutput,
384389
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
385390

391+
void convertAudioAVFrameToFrameOutputOnCPU(
392+
AVFrameStream& avFrameStream,
393+
FrameOutput& frameOutput,
394+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
395+
386396
torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame);
387397

388398
int convertAVFrameToTensorUsingSwsScale(
389399
const AVFrame* avFrame,
390400
torch::Tensor& outputTensor);
391401

402+
FrameBatchOutput makeFrameBatchOutput(int64_t numFrames);
392403
// --------------------------------------------------------------------------
393404
// COLOR CONVERSION LIBRARIES HANDLERS CREATION
394405
// --------------------------------------------------------------------------
@@ -477,6 +488,7 @@ class VideoDecoder {
477488
bool scannedAllStreams_ = false;
478489
// Tracks that we've already been initialized.
479490
bool initialized_ = false;
491+
bool alreadyCalledGetFramesPlayedInRange_ = false;
480492
};
481493

482494
// --------------------------------------------------------------------------

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,12 @@ void add_audio_stream(
232232
}
233233

234234
void seek_to_pts(at::Tensor& decoder, double seconds) {
235+
// TODO-AUDIO we should prevent more than one call to this op for audio
236+
// streams, for the same reasons we do so for getFramesPlayedInRange(). But we
237+
// can't implement the logic here, because we don't know media type (audio vs
238+
// video). We also can't do it within setCursorPtsInSeconds because it's used
239+
// by all other decoding methods. This isn't un-doable, just not easy with
240+
// the API we currently have.
235241
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
236242
videoDecoder->setCursorPtsInSeconds(seconds);
237243
}

0 commit comments

Comments
 (0)