@@ -23,6 +23,7 @@ extern "C" {
2323#include < libavutil/imgutils.h>
2424#include < libavutil/log.h>
2525#include < libavutil/pixdesc.h>
26+ #include < libswresample/swresample.h>
2627#include < libswscale/swscale.h>
2728}
2829
@@ -557,6 +558,12 @@ void VideoDecoder::addAudioStream(int streamIndex) {
557558 static_cast <int64_t >(streamInfo.codecContext ->sample_rate );
558559 streamMetadata.numChannels =
559560 static_cast <int64_t >(getNumChannels (streamInfo.codecContext ));
561+
562+ // FFmpeg docs say that the decoder will try to decode natively in this
563+ // format, if it can. Docs don't say what the decoder does when it doesn't
564+ // support that format, but it looks like it does nothing, so this probably
565+ // doesn't hurt.
566+ streamInfo.codecContext ->request_sample_fmt = AV_SAMPLE_FMT_FLTP;
560567}
561568
562569// --------------------------------------------------------------------------
@@ -1348,37 +1355,89 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13481355 !preAllocatedOutputTensor.has_value (),
13491356 " pre-allocated audio tensor not supported yet." );
13501357
1351- const AVFrame* avFrame = avFrameStream.avFrame .get ();
1358+ AVSampleFormat sourceSampleFormat =
1359+ static_cast <AVSampleFormat>(avFrameStream.avFrame ->format );
1360+ AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
1361+
1362+ UniqueAVFrame convertedAVFrame;
1363+ if (sourceSampleFormat != desiredSampleFormat) {
1364+ convertedAVFrame = convertAudioAVFrameSampleFormat (
1365+ avFrameStream.avFrame , sourceSampleFormat, desiredSampleFormat);
1366+ }
1367+ const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
1368+ ? convertedAVFrame
1369+ : avFrameStream.avFrame ;
1370+
1371+ AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1372+ TORCH_CHECK (
1373+ format == desiredSampleFormat,
1374+ " Something went wrong, the frame didn't get converted to the desired format. " ,
1375+ " Desired format = " ,
1376+ av_get_sample_fmt_name (desiredSampleFormat),
1377+ " source format = " ,
1378+ av_get_sample_fmt_name (format));
13521379
13531380 auto numSamples = avFrame->nb_samples ; // per channel
13541381 auto numChannels = getNumChannels (avFrame);
13551382 torch::Tensor outputData =
13561383 torch::empty ({numChannels, numSamples}, torch::kFloat32 );
13571384
1358- AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1359- // TODO-AUDIO Implement all formats.
1360- switch (format) {
1361- case AV_SAMPLE_FMT_FLTP: {
1362- uint8_t * outputChannelData = static_cast <uint8_t *>(outputData.data_ptr ());
1363- auto numBytesPerChannel = numSamples * av_get_bytes_per_sample (format);
1364- for (auto channel = 0 ; channel < numChannels;
1365- ++channel, outputChannelData += numBytesPerChannel) {
1366- memcpy (
1367- outputChannelData,
1368- avFrame->extended_data [channel],
1369- numBytesPerChannel);
1370- }
1371- break ;
1372- }
1373- default :
1374- TORCH_CHECK (
1375- false ,
1376- " Unsupported audio format (yet!): " ,
1377- av_get_sample_fmt_name (format));
1385+ uint8_t * outputChannelData = static_cast <uint8_t *>(outputData.data_ptr ());
1386+ auto numBytesPerChannel = numSamples * av_get_bytes_per_sample (format);
1387+ for (auto channel = 0 ; channel < numChannels;
1388+ ++channel, outputChannelData += numBytesPerChannel) {
1389+ memcpy (
1390+ outputChannelData, avFrame->extended_data [channel], numBytesPerChannel);
13781391 }
13791392 frameOutput.data = outputData;
13801393}
13811394
1395+ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat (
1396+ const UniqueAVFrame& avFrame,
1397+ AVSampleFormat sourceSampleFormat,
1398+ AVSampleFormat desiredSampleFormat
1399+
1400+ ) {
1401+ auto & streamInfo = streamInfos_[activeStreamIndex_];
1402+ const auto & streamMetadata =
1403+ containerMetadata_.allStreamMetadata [activeStreamIndex_];
1404+ int sampleRate = static_cast <int >(streamMetadata.sampleRate .value ());
1405+
1406+ if (!streamInfo.swrContext ) {
1407+ createSwrContext (
1408+ streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
1409+ }
1410+
1411+ UniqueAVFrame convertedAVFrame (av_frame_alloc ());
1412+ TORCH_CHECK (
1413+ convertedAVFrame,
1414+ " Could not allocate frame for sample format conversion." );
1415+
1416+ setChannelLayout (convertedAVFrame, avFrame);
1417+ convertedAVFrame->format = static_cast <int >(desiredSampleFormat);
1418+ convertedAVFrame->sample_rate = avFrame->sample_rate ;
1419+ convertedAVFrame->nb_samples = avFrame->nb_samples ;
1420+
1421+ auto status = av_frame_get_buffer (convertedAVFrame.get (), 0 );
1422+ TORCH_CHECK (
1423+ status == AVSUCCESS,
1424+ " Could not allocate frame buffers for sample format conversion: " ,
1425+ getFFMPEGErrorStringFromErrorCode (status));
1426+
1427+ auto numSampleConverted = swr_convert (
1428+ streamInfo.swrContext .get (),
1429+ convertedAVFrame->data ,
1430+ convertedAVFrame->nb_samples ,
1431+ static_cast <const uint8_t **>(const_cast <const uint8_t **>(avFrame->data )),
1432+ avFrame->nb_samples );
1433+ TORCH_CHECK (
1434+ numSampleConverted > 0 ,
1435+ " Error in swr_convert: " ,
1436+ getFFMPEGErrorStringFromErrorCode (numSampleConverted));
1437+
1438+ return convertedAVFrame;
1439+ }
1440+
13821441// --------------------------------------------------------------------------
13831442// OUTPUT ALLOCATION AND SHAPE CONVERSION
13841443// --------------------------------------------------------------------------
@@ -1612,6 +1671,25 @@ void VideoDecoder::createSwsContext(
16121671 streamInfo.swsContext .reset (swsContext);
16131672}
16141673
1674+ void VideoDecoder::createSwrContext (
1675+ StreamInfo& streamInfo,
1676+ int sampleRate,
1677+ AVSampleFormat sourceSampleFormat,
1678+ AVSampleFormat desiredSampleFormat) {
1679+ auto swrContext = allocateSwrContext (
1680+ streamInfo.codecContext ,
1681+ sampleRate,
1682+ sourceSampleFormat,
1683+ desiredSampleFormat);
1684+
1685+ auto status = swr_init (swrContext);
1686+ TORCH_CHECK (
1687+ status == AVSUCCESS,
1688+ " Couldn't initialize SwrContext: " ,
1689+ getFFMPEGErrorStringFromErrorCode (status));
1690+ streamInfo.swrContext .reset (swrContext);
1691+ }
1692+
16151693// --------------------------------------------------------------------------
16161694// PTS <-> INDEX CONVERSIONS
16171695// --------------------------------------------------------------------------
0 commit comments