@@ -23,6 +23,7 @@ extern "C" {
2323#include < libavutil/imgutils.h>
2424#include < libavutil/log.h>
2525#include < libavutil/pixdesc.h>
26+ #include < libswresample/swresample.h>
2627#include < libswscale/swscale.h>
2728}
2829
@@ -467,6 +468,7 @@ void VideoDecoder::addStream(
467468 TORCH_CHECK_EQ (retVal, AVSUCCESS);
468469
469470 streamInfo.codecContext ->thread_count = ffmpegThreadCount.value_or (0 );
471+ streamInfo.codecContext ->pkt_timebase = streamInfo.stream ->time_base ;
470472
471473 // TODO_CODE_QUALITY same as above.
472474 if (mediaType == AVMEDIA_TYPE_VIDEO && device.type () == torch::kCUDA ) {
@@ -558,6 +560,12 @@ void VideoDecoder::addAudioStream(int streamIndex) {
558560 static_cast <int64_t >(streamInfo.codecContext ->sample_rate );
559561 streamMetadata.numChannels =
560562 static_cast <int64_t >(getNumChannels (streamInfo.codecContext ));
563+
564+ // FFmpeg docs say that the decoder will try to decode natively in this
565+ // format, if it can. Docs don't say what the decoder does when it doesn't
566+ // support that format, but it looks like it does nothing, so this probably
567+ // doesn't hurt.
568+ streamInfo.codecContext ->request_sample_fmt = AV_SAMPLE_FMT_FLTP;
561569}
562570
563571// --------------------------------------------------------------------------
@@ -566,13 +574,15 @@ void VideoDecoder::addAudioStream(int streamIndex) {
566574
567575VideoDecoder::FrameOutput VideoDecoder::getNextFrame () {
568576 auto output = getNextFrameInternal ();
569- output.data = maybePermuteHWC2CHW (output.data );
577+ if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
578+ output.data = maybePermuteHWC2CHW (output.data );
579+ }
570580 return output;
571581}
572582
573583VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal (
574584 std::optional<torch::Tensor> preAllocatedOutputTensor) {
575- validateActiveStream (AVMEDIA_TYPE_VIDEO );
585+ validateActiveStream ();
576586 AVFrameStream avFrameStream = decodeAVFrame (
577587 [this ](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
578588 return convertAVFrameToFrameOutput (avFrameStream, preAllocatedOutputTensor);
@@ -868,7 +878,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
868878 // If we need to seek backwards, then we have to seek back to the beginning
869879 // of the stream.
870880 // TODO-AUDIO: document why this is needed in a big comment.
871- setCursorPtsInSeconds (INT64_MIN);
881+ setCursorPtsInSecondsInternal (INT64_MIN);
872882 }
873883
874884 // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
@@ -914,6 +924,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
914924// --------------------------------------------------------------------------
915925
916926void VideoDecoder::setCursorPtsInSeconds (double seconds) {
927+ validateActiveStream (AVMEDIA_TYPE_VIDEO);
928+ setCursorPtsInSecondsInternal (seconds);
929+ }
930+
931+ void VideoDecoder::setCursorPtsInSecondsInternal (double seconds) {
917932 cursorWasJustSet_ = true ;
918933 cursor_ =
919934 secondsToClosestPts (seconds, streamInfos_[activeStreamIndex_].timeBase );
@@ -1342,37 +1357,89 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13421357 !preAllocatedOutputTensor.has_value (),
13431358 " pre-allocated audio tensor not supported yet." );
13441359
1345- const AVFrame* avFrame = avFrameStream.avFrame .get ();
1360+ AVSampleFormat sourceSampleFormat =
1361+ static_cast <AVSampleFormat>(avFrameStream.avFrame ->format );
1362+ AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
1363+
1364+ UniqueAVFrame convertedAVFrame;
1365+ if (sourceSampleFormat != desiredSampleFormat) {
1366+ convertedAVFrame = convertAudioAVFrameSampleFormat (
1367+ avFrameStream.avFrame , sourceSampleFormat, desiredSampleFormat);
1368+ }
1369+ const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
1370+ ? convertedAVFrame
1371+ : avFrameStream.avFrame ;
1372+
1373+ AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1374+ TORCH_CHECK (
1375+ format == desiredSampleFormat,
1376+ " Something went wrong, the frame didn't get converted to the desired format. " ,
1377+ " Desired format = " ,
1378+ av_get_sample_fmt_name (desiredSampleFormat),
1379+ " source format = " ,
1380+ av_get_sample_fmt_name (format));
13461381
13471382 auto numSamples = avFrame->nb_samples ; // per channel
13481383 auto numChannels = getNumChannels (avFrame);
13491384 torch::Tensor outputData =
13501385 torch::empty ({numChannels, numSamples}, torch::kFloat32 );
13511386
1352- AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1353- // TODO-AUDIO Implement all formats.
1354- switch (format) {
1355- case AV_SAMPLE_FMT_FLTP: {
1356- uint8_t * outputChannelData = static_cast <uint8_t *>(outputData.data_ptr ());
1357- auto numBytesPerChannel = numSamples * av_get_bytes_per_sample (format);
1358- for (auto channel = 0 ; channel < numChannels;
1359- ++channel, outputChannelData += numBytesPerChannel) {
1360- memcpy (
1361- outputChannelData,
1362- avFrame->extended_data [channel],
1363- numBytesPerChannel);
1364- }
1365- break ;
1366- }
1367- default :
1368- TORCH_CHECK (
1369- false ,
1370- " Unsupported audio format (yet!): " ,
1371- av_get_sample_fmt_name (format));
1387+ uint8_t * outputChannelData = static_cast <uint8_t *>(outputData.data_ptr ());
1388+ auto numBytesPerChannel = numSamples * av_get_bytes_per_sample (format);
1389+ for (auto channel = 0 ; channel < numChannels;
1390+ ++channel, outputChannelData += numBytesPerChannel) {
1391+ memcpy (
1392+ outputChannelData, avFrame->extended_data [channel], numBytesPerChannel);
13721393 }
13731394 frameOutput.data = outputData;
13741395}
13751396
1397+ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat (
1398+ const UniqueAVFrame& avFrame,
1399+ AVSampleFormat sourceSampleFormat,
1400+ AVSampleFormat desiredSampleFormat
1401+
1402+ ) {
1403+ auto & streamInfo = streamInfos_[activeStreamIndex_];
1404+ const auto & streamMetadata =
1405+ containerMetadata_.allStreamMetadata [activeStreamIndex_];
1406+ int sampleRate = static_cast <int >(streamMetadata.sampleRate .value ());
1407+
1408+ if (!streamInfo.swrContext ) {
1409+ createSwrContext (
1410+ streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
1411+ }
1412+
1413+ UniqueAVFrame convertedAVFrame (av_frame_alloc ());
1414+ TORCH_CHECK (
1415+ convertedAVFrame,
1416+ " Could not allocate frame for sample format conversion." );
1417+
1418+ setChannelLayout (convertedAVFrame, avFrame);
1419+ convertedAVFrame->format = static_cast <int >(desiredSampleFormat);
1420+ convertedAVFrame->sample_rate = avFrame->sample_rate ;
1421+ convertedAVFrame->nb_samples = avFrame->nb_samples ;
1422+
1423+ auto status = av_frame_get_buffer (convertedAVFrame.get (), 0 );
1424+ TORCH_CHECK (
1425+ status == AVSUCCESS,
1426+ " Could not allocate frame buffers for sample format conversion: " ,
1427+ getFFMPEGErrorStringFromErrorCode (status));
1428+
1429+ auto numSampleConverted = swr_convert (
1430+ streamInfo.swrContext .get (),
1431+ convertedAVFrame->data ,
1432+ convertedAVFrame->nb_samples ,
1433+ static_cast <const uint8_t **>(const_cast <const uint8_t **>(avFrame->data )),
1434+ avFrame->nb_samples );
1435+ TORCH_CHECK (
1436+ numSampleConverted > 0 ,
1437+ " Error in swr_convert: " ,
1438+ getFFMPEGErrorStringFromErrorCode (numSampleConverted));
1439+
1440+ return convertedAVFrame;
1441+ }
1442+
13761443// --------------------------------------------------------------------------
13771444// OUTPUT ALLOCATION AND SHAPE CONVERSION
13781445// --------------------------------------------------------------------------
@@ -1606,6 +1673,25 @@ void VideoDecoder::createSwsContext(
16061673 streamInfo.swsContext .reset (swsContext);
16071674}
16081675
1676+ void VideoDecoder::createSwrContext (
1677+ StreamInfo& streamInfo,
1678+ int sampleRate,
1679+ AVSampleFormat sourceSampleFormat,
1680+ AVSampleFormat desiredSampleFormat) {
1681+ auto swrContext = allocateSwrContext (
1682+ streamInfo.codecContext ,
1683+ sampleRate,
1684+ sourceSampleFormat,
1685+ desiredSampleFormat);
1686+
1687+ auto status = swr_init (swrContext);
1688+ TORCH_CHECK (
1689+ status == AVSUCCESS,
1690+ " Couldn't initialize SwrContext: " ,
1691+ getFFMPEGErrorStringFromErrorCode (status));
1692+ streamInfo.swrContext .reset (swrContext);
1693+ }
1694+
16091695// --------------------------------------------------------------------------
16101696// PTS <-> INDEX CONVERSIONS
16111697// --------------------------------------------------------------------------
0 commit comments