@@ -512,14 +512,18 @@ void VideoDecoder::addVideoStream(
512512 videoStreamOptions.colorConversionLibrary .value_or (defaultLibrary);
513513}
514514
515- void VideoDecoder::addAudioStream (int streamIndex) {
515+ void VideoDecoder::addAudioStream (
516+ int streamIndex,
517+ const AudioStreamOptions& audioStreamOptions) {
516518 TORCH_CHECK (
517519 seekMode_ == SeekMode::approximate,
518520 " seek_mode must be 'approximate' for audio streams." );
519521
520522 addStream (streamIndex, AVMEDIA_TYPE_AUDIO);
521523
522524 auto & streamInfo = streamInfos_[activeStreamIndex_];
525+ streamInfo.audioStreamOptions = audioStreamOptions;
526+
523527 auto & streamMetadata =
524528 containerMetadata_.allStreamMetadata [activeStreamIndex_];
525529 streamMetadata.sampleRate =
@@ -879,6 +883,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
879883 (stopPts <= lastDecodedAvFrameEnd);
880884 }
881885
886+ auto lastSamples = maybeFlushSwrBuffers ();
887+ if (lastSamples.has_value ()) {
888+ frames.push_back (*lastSamples);
889+ }
890+
882891 return AudioFramesOutput{torch::cat (frames, 1 ), firstFramePtsSeconds};
883892}
884893
@@ -1132,8 +1141,7 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
11321141 getDuration (avFrame),
11331142 formatContext_->streams [activeStreamIndex_]->time_base );
11341143 if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1135- convertAudioAVFrameToFrameOutputOnCPU (
1136- avFrame, frameOutput, preAllocatedOutputTensor);
1144+ convertAudioAVFrameToFrameOutputOnCPU (avFrame, frameOutput);
11371145 } else if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
11381146 convertAVFrameToFrameOutputOnCPU (
11391147 avFrame, frameOutput, preAllocatedOutputTensor);
@@ -1311,24 +1319,30 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
13111319
13121320void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU (
13131321 UniqueAVFrame& srcAVFrame,
1314- FrameOutput& frameOutput,
1315- std::optional<torch::Tensor> preAllocatedOutputTensor) {
1316- TORCH_CHECK (
1317- !preAllocatedOutputTensor.has_value (),
1318- " pre-allocated audio tensor not supported yet." );
1319-
1322+ FrameOutput& frameOutput) {
13201323 AVSampleFormat sourceSampleFormat =
13211324 static_cast <AVSampleFormat>(srcAVFrame->format );
13221325 AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
13231326
1327+ int sourceSampleRate = srcAVFrame->sample_rate ;
1328+ int desiredSampleRate =
1329+ streamInfos_[activeStreamIndex_].audioStreamOptions .sampleRate .value_or (
1330+ sourceSampleRate);
1331+
1332+ bool mustConvert =
1333+ (sourceSampleFormat != desiredSampleFormat ||
1334+ sourceSampleRate != desiredSampleRate);
1335+
13241336 UniqueAVFrame convertedAVFrame;
1325- if (sourceSampleFormat != desiredSampleFormat) {
1326- convertedAVFrame = convertAudioAVFrameSampleFormat (
1327- srcAVFrame, sourceSampleFormat, desiredSampleFormat);
1337+ if (mustConvert) {
1338+ convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate (
1339+ srcAVFrame,
1340+ sourceSampleFormat,
1341+ desiredSampleFormat,
1342+ sourceSampleRate,
1343+ desiredSampleRate);
13281344 }
1329- const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
1330- ? convertedAVFrame
1331- : srcAVFrame;
1345+ const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
13321346
13331347 AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
13341348 TORCH_CHECK (
@@ -1351,55 +1365,110 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13511365 memcpy (
13521366 outputChannelData, avFrame->extended_data [channel], numBytesPerChannel);
13531367 }
1368+
13541369 frameOutput.data = outputData;
13551370}
13561371
1357- UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat (
1358- const UniqueAVFrame& avFrame ,
1372+ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate (
1373+ const UniqueAVFrame& srcAVFrame ,
13591374 AVSampleFormat sourceSampleFormat,
1360- AVSampleFormat desiredSampleFormat
1361-
1362- ) {
1375+ AVSampleFormat desiredSampleFormat,
1376+ int sourceSampleRate,
1377+ int desiredSampleRate ) {
13631378 auto & streamInfo = streamInfos_[activeStreamIndex_];
1364- const auto & streamMetadata =
1365- containerMetadata_.allStreamMetadata [activeStreamIndex_];
1366- int sampleRate = static_cast <int >(streamMetadata.sampleRate .value ());
13671379
13681380 if (!streamInfo.swrContext ) {
13691381 createSwrContext (
1370- streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
1382+ streamInfo,
1383+ sourceSampleFormat,
1384+ desiredSampleFormat,
1385+ sourceSampleRate,
1386+ desiredSampleRate);
13711387 }
13721388
13731389 UniqueAVFrame convertedAVFrame (av_frame_alloc ());
13741390 TORCH_CHECK (
13751391 convertedAVFrame,
13761392 " Could not allocate frame for sample format conversion." );
13771393
1378- setChannelLayout (convertedAVFrame, avFrame );
1394+ setChannelLayout (convertedAVFrame, srcAVFrame );
13791395 convertedAVFrame->format = static_cast <int >(desiredSampleFormat);
1380- convertedAVFrame->sample_rate = avFrame->sample_rate ;
1381- convertedAVFrame->nb_samples = avFrame->nb_samples ;
1396+ convertedAVFrame->sample_rate = desiredSampleRate;
1397+ if (sourceSampleRate != desiredSampleRate) {
1398+ // Note that this is an upper bound on the number of output samples.
1399+ // `swr_convert()` will likely not fill convertedAVFrame with that many
1400+ // samples if sample rate conversion is needed. It will buffer the last few
1401+ // ones because those require future samples. That's also why we reset
1402+ // nb_samples after the call to `swr_convert()`.
1403+ // We could also use `swr_get_out_samples()` to determine the number of
1404+ // output samples, but empirically `av_rescale_rnd()` seems to provide a
1405+ // tighter bound.
1406+ convertedAVFrame->nb_samples = av_rescale_rnd (
1407+ swr_get_delay (streamInfo.swrContext .get (), sourceSampleRate) +
1408+ srcAVFrame->nb_samples ,
1409+ desiredSampleRate,
1410+ sourceSampleRate,
1411+ AV_ROUND_UP);
1412+ } else {
1413+ convertedAVFrame->nb_samples = srcAVFrame->nb_samples ;
1414+ }
13821415
13831416 auto status = av_frame_get_buffer (convertedAVFrame.get (), 0 );
13841417 TORCH_CHECK (
13851418 status == AVSUCCESS,
13861419 " Could not allocate frame buffers for sample format conversion: " ,
13871420 getFFMPEGErrorStringFromErrorCode (status));
13881421
1389- auto numSampleConverted = swr_convert (
1422+ auto numConvertedSamples = swr_convert (
13901423 streamInfo.swrContext .get (),
13911424 convertedAVFrame->data ,
13921425 convertedAVFrame->nb_samples ,
1393- static_cast <const uint8_t **>(const_cast <const uint8_t **>(avFrame->data )),
1394- avFrame->nb_samples );
1426+ static_cast <const uint8_t **>(
1427+ const_cast <const uint8_t **>(srcAVFrame->data )),
1428+ srcAVFrame->nb_samples );
13951429 TORCH_CHECK (
1396- numSampleConverted > 0 ,
1430+ numConvertedSamples > 0 ,
13971431 " Error in swr_convert: " ,
1398- getFFMPEGErrorStringFromErrorCode (numSampleConverted));
1432+ getFFMPEGErrorStringFromErrorCode (numConvertedSamples));
1433+
1434+ // See comment above about nb_samples
1435+ convertedAVFrame->nb_samples = numConvertedSamples;
13991436
14001437 return convertedAVFrame;
14011438}
14021439
1440+ std::optional<torch::Tensor> VideoDecoder::maybeFlushSwrBuffers () {
1441+ // When sample rate conversion is involved, swresample buffers some of the
1442+ // samples in-between calls to swr_convert (see the libswresample docs).
1443+ // That's because the last few samples in a given frame require future samples
1444+ // from the next frame to be properly converted. This function flushes out the
1445+ // samples that are stored in swresample's buffers.
1446+ auto & streamInfo = streamInfos_[activeStreamIndex_];
1447+ if (!streamInfo.swrContext ) {
1448+ return std::nullopt ;
1449+ }
1450+ auto numRemainingSamples = // this is an upper bound
1451+ swr_get_out_samples (streamInfo.swrContext .get (), 0 );
1452+
1453+ if (numRemainingSamples == 0 ) {
1454+ return std::nullopt ;
1455+ }
1456+
1457+ torch::Tensor lastSamples = torch::empty (
1458+ {getNumChannels (streamInfo.codecContext ), numRemainingSamples},
1459+ torch::kFloat32 );
1460+ uint8_t * lastSamplesData = static_cast <uint8_t *>(lastSamples.data_ptr ());
1461+
1462+ auto actualNumRemainingSamples = swr_convert (
1463+ streamInfo.swrContext .get (),
1464+ &lastSamplesData,
1465+ numRemainingSamples,
1466+ nullptr ,
1467+ 0 );
1468+ return lastSamples.narrow (
1469+ /* dim=*/ 1 , /* start=*/ 0 , /* length=*/ actualNumRemainingSamples);
1470+ }
1471+
14031472// --------------------------------------------------------------------------
14041473// OUTPUT ALLOCATION AND SHAPE CONVERSION
14051474// --------------------------------------------------------------------------
@@ -1635,14 +1704,16 @@ void VideoDecoder::createSwsContext(
16351704
16361705void VideoDecoder::createSwrContext (
16371706 StreamInfo& streamInfo,
1638- int sampleRate,
16391707 AVSampleFormat sourceSampleFormat,
1640- AVSampleFormat desiredSampleFormat) {
1708+ AVSampleFormat desiredSampleFormat,
1709+ int sourceSampleRate,
1710+ int desiredSampleRate) {
16411711 auto swrContext = allocateSwrContext (
16421712 streamInfo.codecContext ,
1643- sampleRate,
16441713 sourceSampleFormat,
1645- desiredSampleFormat);
1714+ desiredSampleFormat,
1715+ sourceSampleRate,
1716+ desiredSampleRate);
16461717
16471718 auto status = swr_init (swrContext);
16481719 TORCH_CHECK (
0 commit comments