88#include < cstdint>
99#include < cstdio>
1010#include < iostream>
11+ #include < limits>
1112#include < sstream>
1213#include < stdexcept>
1314#include < string_view>
@@ -552,7 +553,8 @@ void VideoDecoder::addAudioStream(int streamIndex) {
552553 containerMetadata_.allStreamMetadata [activeStreamIndex_];
553554 streamMetadata.sampleRate =
554555 static_cast <int64_t >(streamInfo.codecContext ->sample_rate );
555- streamMetadata.numChannels = getNumChannels (streamInfo.codecContext );
556+ streamMetadata.numChannels =
557+ static_cast <int64_t >(getNumChannels (streamInfo.codecContext ));
556558}
557559
558560// --------------------------------------------------------------------------
@@ -567,6 +569,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
567569
568570VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal (
569571 std::optional<torch::Tensor> preAllocatedOutputTensor) {
572+ validateActiveStream (AVMEDIA_TYPE_VIDEO);
570573 AVFrameStream avFrameStream = decodeAVFrame (
571574 [this ](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
572575 return convertAVFrameToFrameOutput (avFrameStream, preAllocatedOutputTensor);
@@ -685,6 +688,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
685688}
686689
687690VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt (double seconds) {
691+ validateActiveStream (AVMEDIA_TYPE_VIDEO);
688692 StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
689693 double frameStartTime =
690694 ptsToSeconds (streamInfo.lastDecodedAvFramePts , streamInfo.timeBase );
@@ -757,7 +761,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
757761 double startSeconds,
758762 double stopSeconds) {
759763 validateActiveStream (AVMEDIA_TYPE_VIDEO);
760-
761764 const auto & streamMetadata =
762765 containerMetadata_.allStreamMetadata [activeStreamIndex_];
763766 TORCH_CHECK (
@@ -835,6 +838,74 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
835838 return frameBatchOutput;
836839}
837840
841+ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio (
842+ double startSeconds,
843+ std::optional<double > stopSecondsOptional) {
844+ validateActiveStream (AVMEDIA_TYPE_AUDIO);
845+
846+ double stopSeconds =
847+ stopSecondsOptional.value_or (std::numeric_limits<double >::max ());
848+
849+ TORCH_CHECK (
850+ startSeconds <= stopSeconds,
851+ " Start seconds (" + std::to_string (startSeconds) +
852+ " ) must be less than or equal to stop seconds (" +
853+ std::to_string (stopSeconds) + " )." );
854+
855+ if (startSeconds == stopSeconds) {
856+ // For consistency with video
857+ return AudioFramesOutput{torch::empty ({0 }), 0.0 };
858+ }
859+
860+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
861+
862+ auto startPts = secondsToClosestPts (startSeconds, streamInfo.timeBase );
863+ if (startPts < streamInfo.lastDecodedAvFramePts +
864+ streamInfo.lastDecodedAvFrameDuration ) {
865+ // If we need to seek backwards, then we have to seek back to the beginning
866+ // of the stream.
867+ // TODO-AUDIO: document why this is needed in a big comment.
868+ setCursorPtsInSeconds (INT64_MIN);
869+ }
870+
871+ // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
872+ // cat(). This would save a copy. We know the duration of the output and the
873+ // sample rate, so in theory we know the number of output samples.
874+ std::vector<torch::Tensor> frames;
875+
876+ double firstFramePtsSeconds = std::numeric_limits<double >::max ();
877+ auto stopPts = secondsToClosestPts (stopSeconds, streamInfo.timeBase );
878+ auto finished = false ;
879+ while (!finished) {
880+ try {
881+ AVFrameStream avFrameStream = decodeAVFrame ([startPts](AVFrame* avFrame) {
882+ return startPts < avFrame->pts + getDuration (avFrame);
883+ });
884+ // TODO: it's not great that we are getting a FrameOutput, which is
885+ // intended for videos. We should consider bypassing
886+ // convertAVFrameToFrameOutput and directly call
887+ // convertAudioAVFrameToFrameOutputOnCPU.
888+ auto frameOutput = convertAVFrameToFrameOutput (avFrameStream);
889+ firstFramePtsSeconds =
890+ std::min (firstFramePtsSeconds, frameOutput.ptsSeconds );
891+ frames.push_back (frameOutput.data );
892+ } catch (const EndOfFileException& e) {
893+ finished = true ;
894+ }
895+
896+ // If stopSeconds is in [begin, end] of the last decoded frame, we should
897+ // stop decoding more frames. Note that if we were to use [begin, end),
898+ // which may seem more natural, then we would decode the frame starting at
899+ // stopSeconds, which isn't what we want!
900+ auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
901+ streamInfo.lastDecodedAvFrameDuration ;
902+ finished |= (streamInfo.lastDecodedAvFramePts ) <= stopPts &&
903+ (stopPts <= lastDecodedAvFrameEnd);
904+ }
905+
906+ return AudioFramesOutput{torch::cat (frames, 1 ), firstFramePtsSeconds};
907+ }
908+
838909// --------------------------------------------------------------------------
839910// SEEKING APIs
840911// --------------------------------------------------------------------------
@@ -871,6 +942,12 @@ I P P P I P P P I P P I P P I P
871942(2) is more efficient than (1) if there is an I frame between x and y.
872943*/
873944bool VideoDecoder::canWeAvoidSeeking () const {
945+ const StreamInfo& streamInfo = streamInfos_.at (activeStreamIndex_);
946+ if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
947+ // For audio, we only need to seek if a backwards seek was requested within
948+ // getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called.
949+ return !cursorWasJustSet_;
950+ }
874951 int64_t lastDecodedAvFramePts =
875952 streamInfos_.at (activeStreamIndex_).lastDecodedAvFramePts ;
876953 if (cursor_ < lastDecodedAvFramePts) {
@@ -897,7 +974,7 @@ bool VideoDecoder::canWeAvoidSeeking() const {
897974// AVFormatContext if it is needed. We can skip seeking in certain cases. See
898975// the comment of canWeAvoidSeeking() for details.
899976void VideoDecoder::maybeSeekToBeforeDesiredPts () {
900- validateActiveStream (AVMEDIA_TYPE_VIDEO );
977+ validateActiveStream ();
901978 StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
902979
903980 decodeStats_.numSeeksAttempted ++;
@@ -942,7 +1019,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
9421019
9431020VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame (
9441021 std::function<bool (AVFrame*)> filterFunction) {
945- validateActiveStream (AVMEDIA_TYPE_VIDEO );
1022+ validateActiveStream ();
9461023
9471024 resetDecodeStats ();
9481025
@@ -1071,13 +1148,14 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
10711148 AVFrame* avFrame = avFrameStream.avFrame .get ();
10721149 frameOutput.streamIndex = streamIndex;
10731150 auto & streamInfo = streamInfos_[streamIndex];
1074- TORCH_CHECK (streamInfo.stream ->codecpar ->codec_type == AVMEDIA_TYPE_VIDEO);
10751151 frameOutput.ptsSeconds = ptsToSeconds (
10761152 avFrame->pts , formatContext_->streams [streamIndex]->time_base );
10771153 frameOutput.durationSeconds = ptsToSeconds (
10781154 getDuration (avFrame), formatContext_->streams [streamIndex]->time_base );
1079- // TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
1080- if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
1155+ if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1156+ convertAudioAVFrameToFrameOutputOnCPU (
1157+ avFrameStream, frameOutput, preAllocatedOutputTensor);
1158+ } else if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
10811159 convertAVFrameToFrameOutputOnCPU (
10821160 avFrameStream, frameOutput, preAllocatedOutputTensor);
10831161 } else if (streamInfo.videoStreamOptions .device .type () == torch::kCUDA ) {
@@ -1253,6 +1331,45 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
12531331 filteredAVFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
12541332}
12551333
1334+ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU (
1335+ VideoDecoder::AVFrameStream& avFrameStream,
1336+ FrameOutput& frameOutput,
1337+ std::optional<torch::Tensor> preAllocatedOutputTensor) {
1338+ TORCH_CHECK (
1339+ !preAllocatedOutputTensor.has_value (),
1340+ " pre-allocated audio tensor not supported yet." );
1341+
1342+ const AVFrame* avFrame = avFrameStream.avFrame .get ();
1343+
1344+ auto numSamples = avFrame->nb_samples ; // per channel
1345+ auto numChannels = getNumChannels (avFrame);
1346+ torch::Tensor outputData =
1347+ torch::empty ({numChannels, numSamples}, torch::kFloat32 );
1348+
1349+ AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1350+ // TODO-AUDIO Implement all formats.
1351+ switch (format) {
1352+ case AV_SAMPLE_FMT_FLTP: {
1353+ uint8_t * outputChannelData = static_cast <uint8_t *>(outputData.data_ptr ());
1354+ auto numBytesPerChannel = numSamples * av_get_bytes_per_sample (format);
1355+ for (auto channel = 0 ; channel < numChannels;
1356+ ++channel, outputChannelData += numBytesPerChannel) {
1357+ memcpy (
1358+ outputChannelData,
1359+ avFrame->extended_data [channel],
1360+ numBytesPerChannel);
1361+ }
1362+ break ;
1363+ }
1364+ default :
1365+ TORCH_CHECK (
1366+ false ,
1367+ " Unsupported audio format (yet!): " ,
1368+ av_get_sample_fmt_name (format));
1369+ }
1370+ frameOutput.data = outputData;
1371+ }
1372+
12561373// --------------------------------------------------------------------------
12571374// OUTPUT ALLOCATION AND SHAPE CONVERSION
12581375// --------------------------------------------------------------------------
0 commit comments