@@ -169,6 +169,20 @@ void VideoDecoder::initializeDecoder() {
       }
       containerMetadata_.numVideoStreams++;
     } else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
+      // TODO-AUDIO Remove this, we shouldn't need it. We should probably write
+      // a pts-based "getFramesPlayedInRange" from scratch without going back to
+      // indices.
+      int numSamplesPerFrame = avStream->codecpar->frame_size;
+      int sampleRate = avStream->codecpar->sample_rate;
+      if (numSamplesPerFrame > 0 && sampleRate > 0) {
+        // This should allow the approximate mode to do its magic.
+        // fps is numFrames / duration where
+        // - duration = numSamplesTotal / sampleRate and
+        // - numSamplesTotal = numSamplesPerFrame * numFrames
+        // so fps = numFrames * sampleRate / (numSamplesPerFrame * numFrames)
+        streamMetadata.averageFps =
+            static_cast<double>(sampleRate) / numSamplesPerFrame;
+      }
       containerMetadata_.numAudioStreams++;
     }
 
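The comment block in this hunk derives an average "fps" for fixed-frame-size audio, and the whole derivation reduces to a single division. A minimal standalone sketch with illustrative values (the 1024-sample frame size and 44100 Hz rate are assumptions for the example, not taken from this diff):

```cpp
#include <iostream>

int main() {
  // Illustrative values only: many AAC streams use a fixed frame_size of
  // 1024 samples per channel at a 44100 Hz sample rate.
  int numSamplesPerFrame = 1024;
  int sampleRate = 44100;

  // fps = numFrames * sampleRate / (numSamplesPerFrame * numFrames)
  //     = sampleRate / numSamplesPerFrame
  double averageFps = static_cast<double>(sampleRate) / numSamplesPerFrame;

  std::cout << averageFps << "\n"; // ~43.07 audio frames (packets) per second
  return 0;
}
```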
@@ -549,8 +563,20 @@ void VideoDecoder::addAudioStream(int streamIndex) {
   addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
 
   auto& streamInfo = streamInfos_[activeStreamIndex_];
+
+  // TODO-AUDIO
+  TORCH_CHECK(
+      streamInfo.codecContext->frame_size > 0,
+      "No support for audio variable framerate yet.");
+
   auto& streamMetadata =
       containerMetadata_.allStreamMetadata[activeStreamIndex_];
+
+  // TODO-AUDIO
+  TORCH_CHECK(
+      streamMetadata.averageFps.has_value(),
+      "frame_size or sample_rate aren't known. Cannot decode.");
+
   streamMetadata.sampleRate =
       static_cast<int64_t>(streamInfo.codecContext->sample_rate);
   streamMetadata.numChannels = getNumChannels(streamInfo.codecContext);
@@ -562,7 +588,7 @@ void VideoDecoder::addAudioStream(int streamIndex) {
 
 VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
   auto output = getNextFrameInternal();
-  output.data = maybePermuteHWC2CHW(output.data);
+  output.data = maybePermuteOutputTensor(output.data);
   return output;
 }
 
@@ -576,6 +602,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
 }
 
 VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
+  validateActiveStream(AVMEDIA_TYPE_VIDEO);
   auto frameOutput = getFrameAtIndexInternal(frameIndex);
   frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
   return frameOutput;
@@ -584,7 +611,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
 VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
     int64_t frameIndex,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  validateActiveStream(AVMEDIA_TYPE_VIDEO);
+  validateActiveStream();
 
   const auto& streamInfo = streamInfos_[activeStreamIndex_];
   const auto& streamMetadata =
@@ -688,6 +715,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
 }
 
 VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
+  validateActiveStream(AVMEDIA_TYPE_VIDEO);
   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
   double frameStartTime =
       ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
@@ -759,19 +787,29 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt(
 VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
     double startSeconds,
     double stopSeconds) {
-  validateActiveStream(AVMEDIA_TYPE_VIDEO);
+  validateActiveStream();
+  // Because we currently never seek with audio streams, we prevent users from
+  // calling this method twice. We could allow multiple calls in the future.
+  // Assuming 2 consecutive calls:
+  // ```
+  // getFramesPlayedInRange(startSeconds1, stopSeconds1);
+  // getFramesPlayedInRange(startSeconds2, stopSeconds2);
+  // ```
+  // We would need to seek back to 0 iff startSeconds2 <= stopSeconds1. This
+  // logic is not implemented for now, so we just error.
+
+  TORCH_CHECK(
+      streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO ||
+          !alreadyCalledGetFramesPlayedInRange_,
+      "Can only decode once with audio stream. Re-create a decoder object if needed.");
+  alreadyCalledGetFramesPlayedInRange_ = true;
 
-  const auto& streamMetadata =
-      containerMetadata_.allStreamMetadata[activeStreamIndex_];
   TORCH_CHECK(
       startSeconds <= stopSeconds,
       "Start seconds (" + std::to_string(startSeconds) +
           ") must be less than or equal to stop seconds (" +
           std::to_string(stopSeconds) + ".");
 
-  const auto& streamInfo = streamInfos_[activeStreamIndex_];
-  const auto& videoStreamOptions = streamInfo.videoStreamOptions;
-
   // Special case needed to implement a half-open range. At first glance, this
   // may seem unnecessary, as our search for stopFrame can return the end, and
   // we don't include stopFrameIndex in our output. However, consider the
@@ -790,11 +828,14 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
   // values of the intervals will map to the same frame indices below. Hence, we
   // need this special case below.
   if (startSeconds == stopSeconds) {
-    FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata);
-    frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
+    FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(0);
+    frameBatchOutput.data = maybePermuteOutputTensor(frameBatchOutput.data);
     return frameBatchOutput;
   }
 
+  const auto& streamMetadata =
+      containerMetadata_.allStreamMetadata[activeStreamIndex_];
+
   double minSeconds = getMinSeconds(streamMetadata);
   double maxSeconds = getMaxSeconds(streamMetadata);
   TORCH_CHECK(
@@ -825,15 +866,14 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
   int64_t stopFrameIndex = secondsToIndexUpperBound(stopSeconds);
   int64_t numFrames = stopFrameIndex - startFrameIndex;
 
-  FrameBatchOutput frameBatchOutput(
-      numFrames, videoStreamOptions, streamMetadata);
+  FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(numFrames);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
     FrameOutput frameOutput =
         getFrameAtIndexInternal(i, frameBatchOutput.data[f]);
     frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds;
     frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds;
   }
-  frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
+  frameBatchOutput.data = maybePermuteOutputTensor(frameBatchOutput.data);
 
   return frameBatchOutput;
 }
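Taken together, the hunks above make getFramesPlayedInRange() a one-shot operation when the active stream is audio. A hypothetical caller-side sketch (decoder construction and the addAudioStream() call are elided; only the method name comes from this diff):

```cpp
// Sketch only: `decoder` is assumed to be a VideoDecoder whose active
// stream is an audio stream.
auto first = decoder.getFramesPlayedInRange(0.0, 2.0); // OK: first call
// Any second call throws, even for a non-overlapping range, because the
// guard is a single boolean rather than a range-overlap test:
//   "Can only decode once with audio stream. Re-create a decoder object
//    if needed."
auto second = decoder.getFramesPlayedInRange(5.0, 6.0); // throws
```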
@@ -872,8 +912,12 @@ I P P P I P P P I P P I P P I P
 (2) is more efficient than (1) if there is an I frame between x and y.
 */
 bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
-  int64_t lastDecodedAvFramePts =
-      streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
+  const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
+  if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
+    return true;
+  }
+
+  int64_t lastDecodedAvFramePts = streamInfo.lastDecodedAvFramePts;
   if (targetPts < lastDecodedAvFramePts) {
     // We can never skip a seek if we are seeking backwards.
     return false;
@@ -898,7 +942,7 @@ bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
 // AVFormatContext if it is needed. We can skip seeking in certain cases. See
 // the comment of canWeAvoidSeeking() for details.
 void VideoDecoder::maybeSeekToBeforeDesiredPts() {
-  validateActiveStream(AVMEDIA_TYPE_VIDEO);
+  validateActiveStream();
   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
 
   int64_t desiredPts =
@@ -945,7 +989,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
 
 VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
     std::function<bool(AVFrame*)> filterFunction) {
-  validateActiveStream(AVMEDIA_TYPE_VIDEO);
+  validateActiveStream();
 
   resetDecodeStats();
 
@@ -1075,13 +1119,14 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
   AVFrame* avFrame = avFrameStream.avFrame.get();
   frameOutput.streamIndex = streamIndex;
   auto& streamInfo = streamInfos_[streamIndex];
-  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
   frameOutput.ptsSeconds = ptsToSeconds(
       avFrame->pts, formatContext_->streams[streamIndex]->time_base);
   frameOutput.durationSeconds = ptsToSeconds(
       getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
-  // TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
-  if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
+  if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
+    convertAudioAVFrameToFrameOutputOnCPU(
+        avFrameStream, frameOutput, preAllocatedOutputTensor);
+  } else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
     convertAVFrameToFrameOutputOnCPU(
         avFrameStream, frameOutput, preAllocatedOutputTensor);
   } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
@@ -1257,6 +1302,48 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
       filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 }
 
+void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
+    VideoDecoder::AVFrameStream& avFrameStream,
+    FrameOutput& frameOutput,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  const AVFrame* avFrame = avFrameStream.avFrame.get();
+
+  auto numSamples = avFrame->nb_samples; // per channel
+  auto numChannels = getNumChannels(avFrame);
+
+  // TODO-AUDIO: dtype should be format-dependent
+  // TODO-AUDIO rename data to something else
+  torch::Tensor data;
+  if (preAllocatedOutputTensor.has_value()) {
+    data = preAllocatedOutputTensor.value();
+  } else {
+    data = torch::empty({numChannels, numSamples}, torch::kFloat32);
+  }
+
+  AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
+  // TODO Implement all formats
+  switch (format) {
+    case AV_SAMPLE_FMT_FLTP: {
+      uint8_t* outputChannelData = static_cast<uint8_t*>(data.data_ptr());
+      auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
+      for (auto channel = 0; channel < numChannels;
+           ++channel, outputChannelData += numBytesPerChannel) {
+        memcpy(
+            outputChannelData,
+            avFrame->extended_data[channel],
+            numBytesPerChannel);
+      }
+      break;
+    }
+    default:
+      TORCH_CHECK(
+          false,
+          "Unsupported audio format (yet!): ",
+          av_get_sample_fmt_name(format));
+  }
+  frameOutput.data = data;
+}
+
 // --------------------------------------------------------------------------
 // OUTPUT ALLOCATION AND SHAPE CONVERSION
 // --------------------------------------------------------------------------
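The switch above only handles AV_SAMPLE_FMT_FLTP, where each channel is a contiguous plane in extended_data[channel], so a per-channel memcpy yields the (numChannels, numSamples) layout directly. For comparison, a hypothetical helper for the interleaved float format (AV_SAMPLE_FMT_FLT), which is not part of this diff, would have to gather each channel with a stride:

```cpp
// Hypothetical sketch only: de-interleave AV_SAMPLE_FMT_FLT samples
// (laid out as c0s0 c1s0 c0s1 c1s1 ... in extended_data[0]) into the
// same (numChannels, numSamples) layout the FLTP branch produces.
void deinterleaveFloatSamples(
    const float* interleaved, // would come from avFrame->extended_data[0]
    float* output,            // contiguous (numChannels, numSamples) buffer
    int numChannels,
    int numSamples) {
  for (int channel = 0; channel < numChannels; ++channel) {
    for (int sample = 0; sample < numSamples; ++sample) {
      output[channel * numSamples + sample] =
          interleaved[sample * numChannels + channel];
    }
  }
}
```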
@@ -1275,6 +1362,41 @@ VideoDecoder::FrameBatchOutput::FrameBatchOutput(
       height, width, videoStreamOptions.device, numFrames);
 }
 
+VideoDecoder::FrameBatchOutput::FrameBatchOutput(
+    int64_t numFrames,
+    int64_t numChannels,
+    int64_t numSamples)
+    : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
+      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
+  // TODO handle dtypes other than float
+  auto tensorOptions = torch::TensorOptions()
+                           .dtype(torch::kFloat32)
+                           .layout(torch::kStrided)
+                           .device(torch::kCPU);
+  data = torch::empty({numFrames, numChannels, numSamples}, tensorOptions);
+}
+
+VideoDecoder::FrameBatchOutput VideoDecoder::makeFrameBatchOutput(
+    int64_t numFrames) {
+  const auto& streamInfo = streamInfos_[activeStreamIndex_];
+  if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) {
+    const auto& videoStreamOptions = streamInfo.videoStreamOptions;
+    const auto& streamMetadata =
+        containerMetadata_.allStreamMetadata[activeStreamIndex_];
+    return FrameBatchOutput(numFrames, videoStreamOptions, streamMetadata);
+  } else {
+    // TODO-AUDIO
+    // We asserted that frame_size is non-zero when we added the stream, but it
+    // may not always be the case.
+    // When it's 0, we can't pre-allocate the output tensor as we don't know the
+    // number of samples per channel, and it may be non-constant. We'll have to
+    // find a way to make the batch-APIs work without pre-allocation.
+    int64_t numSamples = streamInfo.codecContext->frame_size;
+    int64_t numChannels = getNumChannels(streamInfo.codecContext);
+    return FrameBatchOutput(numFrames, numChannels, numSamples);
+  }
+}
+
 torch::Tensor allocateEmptyHWCTensor(
     int height,
     int width,
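To make the audio allocation above concrete, here is a small self-contained sketch of the tensor shapes the three-argument constructor produces, with illustrative values (10 frames of a stereo stream whose frame_size is 1024; these numbers are assumptions for the example):

```cpp
#include <torch/torch.h>

int main() {
  int64_t numFrames = 10, numChannels = 2, numSamples = 1024;

  // Mirrors the allocations in the audio FrameBatchOutput constructor.
  torch::Tensor data =
      torch::empty({numFrames, numChannels, numSamples}, torch::kFloat32);
  torch::Tensor ptsSeconds = torch::empty({numFrames}, torch::kFloat64);
  torch::Tensor durationSeconds = torch::empty({numFrames}, torch::kFloat64);

  TORCH_CHECK(data.sizes() == torch::IntArrayRef({10, 2, 1024}));
  TORCH_CHECK(ptsSeconds.sizes() == torch::IntArrayRef({10}));
  return 0;
}
```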
@@ -1296,6 +1418,17 @@ torch::Tensor allocateEmptyHWCTensor(
   }
 }
 
+torch::Tensor VideoDecoder::maybePermuteOutputTensor(
+    torch::Tensor& outputTensor) {
+  if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
+    return maybePermuteHWC2CHW(outputTensor);
+  } else {
+    // No need to do anything for audio. We always return (numChannels,
+    // numSamples) or (numFrames, numChannels, numSamples)
+    return outputTensor;
+  }
+}
+
 // Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
 // The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
 // or 4D.