diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index fb8f6342c..cefb1a983 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -46,8 +46,7 @@ class BetaCudaDeviceInterface : public DeviceInterface { void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - std::optional<torch::Tensor> preAllocatedOutputTensor = - std::nullopt) override; + std::optional<torch::Tensor> preAllocatedOutputTensor) override; int sendPacket(ReferenceAVPacket& packet) override; int sendEOFPacket() override; diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index fe59fe5bb..8526e4b25 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -35,6 +35,7 @@ void CpuDeviceInterface::initializeVideo( const VideoStreamOptions& videoStreamOptions, const std::vector<std::unique_ptr<Transform>>& transforms, const std::optional<FrameDims>& resizedOutputDims) { + avMediaType_ = AVMEDIA_TYPE_VIDEO; videoStreamOptions_ = videoStreamOptions; resizedOutputDims_ = resizedOutputDims; @@ -86,6 +87,13 @@ void CpuDeviceInterface::initializeVideo( initialized_ = true; } +void CpuDeviceInterface::initializeAudio( + const AudioStreamOptions& audioStreamOptions) { + avMediaType_ = AVMEDIA_TYPE_AUDIO; + audioStreamOptions_ = audioStreamOptions; + initialized_ = true; +} + ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary( const FrameDims& outputDims) const { // swscale requires widths to be multiples of 32: @@ -114,6 +122,20 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary( } } +void CpuDeviceInterface::convertAVFrameToFrameOutput( + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional<torch::Tensor> preAllocatedOutputTensor) { + TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized."); + + if (avMediaType_ == AVMEDIA_TYPE_AUDIO) { + convertAudioAVFrameToFrameOutput(avFrame, 
frameOutput); + } else { + convertVideoAVFrameToFrameOutput( + avFrame, frameOutput, preAllocatedOutputTensor); + } +} + // Note [preAllocatedOutputTensor with swscale and filtergraph]: // Callers may pass a pre-allocated tensor, where the output.data tensor will // be stored. This parameter is honored in any case, but it only leads to a @@ -123,12 +145,10 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary( // TODO: Figure out whether that's possible! // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of // `dimension_order` parameter. It's up to callers to re-shape it if needed. -void CpuDeviceInterface::convertAVFrameToFrameOutput( +void CpuDeviceInterface::convertVideoAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional<torch::Tensor> preAllocatedOutputTensor) { - TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized."); - // Note that we ignore the dimensions from the metadata; we don't even bother // storing them. The resized dimensions take priority. If we don't have any, // then we use the dimensions from the actual decoded frame. We use the actual @@ -278,6 +298,123 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( return rgbAVFrameToTensor(filterGraph_->convert(avFrame)); } +void CpuDeviceInterface::convertAudioAVFrameToFrameOutput( + UniqueAVFrame& srcAVFrame, + FrameOutput& frameOutput) { + AVSampleFormat srcSampleFormat = + static_cast<AVSampleFormat>(srcAVFrame->format); + AVSampleFormat outSampleFormat = AV_SAMPLE_FMT_FLTP; + + int srcSampleRate = srcAVFrame->sample_rate; + int outSampleRate = audioStreamOptions_.sampleRate.value_or(srcSampleRate); + + int srcNumChannels = getNumChannels(codecContext_); + TORCH_CHECK( + srcNumChannels == getNumChannels(srcAVFrame), + "The frame has ", + getNumChannels(srcAVFrame), + " channels, expected ", + srcNumChannels, + ". If you are hitting this, it may be because you are using " + "a buggy FFmpeg version. 
FFmpeg4 is known to fail here in some " + "valid scenarios. Try to upgrade FFmpeg?"); + int outNumChannels = audioStreamOptions_.numChannels.value_or(srcNumChannels); + + bool mustConvert = + (srcSampleFormat != outSampleFormat || srcSampleRate != outSampleRate || + srcNumChannels != outNumChannels); + + UniqueAVFrame convertedAVFrame; + if (mustConvert) { + if (!swrContext_) { + swrContext_.reset(createSwrContext( + srcSampleFormat, + outSampleFormat, + srcSampleRate, + outSampleRate, + srcAVFrame, + outNumChannels)); + } + + convertedAVFrame = convertAudioAVFrameSamples( + swrContext_, + srcAVFrame, + outSampleFormat, + outSampleRate, + outNumChannels); + } + const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; + + AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format); + TORCH_CHECK( + format == outSampleFormat, + "Something went wrong, the frame didn't get converted to the desired format. ", + "Desired format = ", + av_get_sample_fmt_name(outSampleFormat), + "source format = ", + av_get_sample_fmt_name(format)); + + int numChannels = getNumChannels(avFrame); + TORCH_CHECK( + numChannels == outNumChannels, + "Something went wrong, the frame didn't get converted to the desired ", + "number of channels = ", + outNumChannels, + ". 
Got ", + numChannels, + " instead."); + + auto numSamples = avFrame->nb_samples; + + frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32); + + if (numSamples > 0) { + uint8_t* outputChannelData = + static_cast<uint8_t*>(frameOutput.data.data_ptr()); + auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); + for (auto channel = 0; channel < numChannels; + ++channel, outputChannelData += numBytesPerChannel) { + std::memcpy( + outputChannelData, + avFrame->extended_data[channel], + numBytesPerChannel); + } + } +} + +std::optional<torch::Tensor> CpuDeviceInterface::maybeFlushAudioBuffers() { + // When sample rate conversion is involved, swresample buffers some of the + // samples in-between calls to swr_convert (see the libswresample docs). + // That's because the last few samples in a given frame require future + // samples from the next frame to be properly converted. This function + // flushes out the samples that are stored in swresample's buffers. + if (!swrContext_) { + return std::nullopt; + } + auto numRemainingSamples = // this is an upper bound + swr_get_out_samples(swrContext_.get(), 0); + + if (numRemainingSamples == 0) { + return std::nullopt; + } + + int numChannels = + audioStreamOptions_.numChannels.value_or(getNumChannels(codecContext_)); + torch::Tensor lastSamples = + torch::empty({numChannels, numRemainingSamples}, torch::kFloat32); + + std::vector<uint8_t*> outputBuffers(numChannels); + for (auto i = 0; i < numChannels; i++) { + outputBuffers[i] = static_cast<uint8_t*>(lastSamples[i].data_ptr()); + } + + auto actualNumRemainingSamples = swr_convert( + swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0); + + return lastSamples.narrow( + /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); +} + std::string CpuDeviceInterface::getDetails() { return std::string("CPU Device Interface."); } diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 65ef272bc..2d1033074 100644 --- 
a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -33,15 +33,28 @@ class CpuDeviceInterface : public DeviceInterface { const std::vector<std::unique_ptr<Transform>>& transforms, const std::optional<FrameDims>& resizedOutputDims) override; + virtual void initializeAudio( + const AudioStreamOptions& audioStreamOptions) override; + + virtual std::optional<torch::Tensor> maybeFlushAudioBuffers() override; + void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - std::optional<torch::Tensor> preAllocatedOutputTensor = - std::nullopt) override; + std::optional<torch::Tensor> preAllocatedOutputTensor) override; std::string getDetails() override; private: + void convertAudioAVFrameToFrameOutput( + UniqueAVFrame& srcAVFrame, + FrameOutput& frameOutput); + + void convertVideoAVFrameToFrameOutput( + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional<torch::Tensor> preAllocatedOutputTensor); + int convertAVFrameToTensorUsingSwScale( const UniqueAVFrame& avFrame, torch::Tensor& outputTensor, @@ -108,6 +121,10 @@ class CpuDeviceInterface : public DeviceInterface { bool userRequestedSwScale_; bool initialized_ = false; + + // Audio-specific members + AudioStreamOptions audioStreamOptions_; + UniqueSwrContext swrContext_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 29f6b8755..c892bd49b 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -37,8 +37,7 @@ class CudaDeviceInterface : public DeviceInterface { void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - std::optional<torch::Tensor> preAllocatedOutputTensor = - std::nullopt) override; + std::optional<torch::Tensor> preAllocatedOutputTensor) override; std::string getDetails() override; diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 1f7636d99..319fe01a8 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ 
b/src/torchcodec/_core/DeviceInterface.h @@ -65,6 +65,21 @@ class DeviceInterface { transforms, [[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {} + // Initialize the device with parameters specific to audio decoding. There is + // a default empty implementation. + virtual void initializeAudio( + [[maybe_unused]] const AudioStreamOptions& audioStreamOptions) {} + + // Flush any remaining samples from the audio resampler buffer. + // When sample rate conversion is involved, some samples may be buffered + // between frames for proper interpolation. This function flushes those + // buffered samples. + // Returns an optional tensor containing the flushed samples, or std::nullopt + // if there are no buffered samples or audio is not supported. + virtual std::optional<torch::Tensor> maybeFlushAudioBuffers() { + return std::nullopt; + } + // In order for decoding to actually happen on an FFmpeg managed hardware // device, we need to register the DeviceInterface managed // AVHardwareDeviceContext with the AVCodecContext. We don't need to do this @@ -126,6 +141,7 @@ class DeviceInterface { protected: torch::Device device_; SharedAVCodecContext codecContext_; + AVMediaType avMediaType_; }; using CreateDeviceInterfaceFn = diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 7ee758172..72cd7afac 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -570,6 +570,9 @@ void SingleStreamDecoder::addAudioStream( // support that format, but it looks like it does nothing, so this probably // doesn't hurt. 
streamInfo.codecContext->request_sample_fmt = AV_SAMPLE_FMT_FLTP; + + // Initialize device interface for audio + deviceInterface_->initializeAudio(audioStreamOptions); } // -------------------------------------------------------------------------- @@ -1025,7 +1028,7 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( (stopPts <= lastDecodedAvFrameEnd); } - auto lastSamples = maybeFlushSwrBuffers(); + auto lastSamples = deviceInterface_->maybeFlushAudioBuffers(); if (lastSamples.has_value()) { frames.push_back(*lastSamples); } @@ -1286,143 +1289,17 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( std::optional<torch::Tensor> preAllocatedOutputTensor) { // Convert the frame to tensor. FrameOutput frameOutput; - auto& streamInfo = streamInfos_[activeStreamIndex_]; frameOutput.ptsSeconds = ptsToSeconds( getPtsOrDts(avFrame), formatContext_->streams[activeStreamIndex_]->time_base); frameOutput.durationSeconds = ptsToSeconds( getDuration(avFrame), formatContext_->streams[activeStreamIndex_]->time_base); - if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { - convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); - } else { - deviceInterface_->convertAVFrameToFrameOutput( - avFrame, frameOutput, preAllocatedOutputTensor); - } + deviceInterface_->convertAVFrameToFrameOutput( + avFrame, frameOutput, preAllocatedOutputTensor); return frameOutput; } -void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( - UniqueAVFrame& srcAVFrame, - FrameOutput& frameOutput) { - AVSampleFormat srcSampleFormat = - static_cast<AVSampleFormat>(srcAVFrame->format); - AVSampleFormat outSampleFormat = AV_SAMPLE_FMT_FLTP; - - StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - int srcSampleRate = srcAVFrame->sample_rate; - int outSampleRate = - streamInfo.audioStreamOptions.sampleRate.value_or(srcSampleRate); - - int srcNumChannels = getNumChannels(streamInfo.codecContext); - TORCH_CHECK( - srcNumChannels == getNumChannels(srcAVFrame), - "The frame has ", - 
getNumChannels(srcAVFrame), - " channels, expected ", - srcNumChannels, - ". If you are hitting this, it may be because you are using " - "a buggy FFmpeg version. FFmpeg4 is known to fail here in some " - "valid scenarios. Try to upgrade FFmpeg?"); - int outNumChannels = - streamInfo.audioStreamOptions.numChannels.value_or(srcNumChannels); - - bool mustConvert = - (srcSampleFormat != outSampleFormat || srcSampleRate != outSampleRate || - srcNumChannels != outNumChannels); - - UniqueAVFrame convertedAVFrame; - if (mustConvert) { - if (!swrContext_) { - swrContext_.reset(createSwrContext( - srcSampleFormat, - outSampleFormat, - srcSampleRate, - outSampleRate, - srcAVFrame, - outNumChannels)); - } - - convertedAVFrame = convertAudioAVFrameSamples( - swrContext_, - srcAVFrame, - outSampleFormat, - outSampleRate, - outNumChannels); - } - const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; - - AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format); - TORCH_CHECK( - format == outSampleFormat, - "Something went wrong, the frame didn't get converted to the desired format. ", - "Desired format = ", - av_get_sample_fmt_name(outSampleFormat), - "source format = ", - av_get_sample_fmt_name(format)); - - int numChannels = getNumChannels(avFrame); - TORCH_CHECK( - numChannels == outNumChannels, - "Something went wrong, the frame didn't get converted to the desired ", - "number of channels = ", - outNumChannels, - ". 
Got ", - numChannels, - " instead."); - - auto numSamples = avFrame->nb_samples; // per channel - - frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32); - - if (numSamples > 0) { - uint8_t* outputChannelData = - static_cast<uint8_t*>(frameOutput.data.data_ptr()); - auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); - for (auto channel = 0; channel < numChannels; - ++channel, outputChannelData += numBytesPerChannel) { - std::memcpy( - outputChannelData, - avFrame->extended_data[channel], - numBytesPerChannel); - } - } -} - -std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() { - // When sample rate conversion is involved, swresample buffers some of the - // samples in-between calls to swr_convert (see the libswresample docs). - // That's because the last few samples in a given frame require future - // samples from the next frame to be properly converted. This function - // flushes out the samples that are stored in swresample's buffers. - auto& streamInfo = streamInfos_[activeStreamIndex_]; - if (!swrContext_) { - return std::nullopt; - } - auto numRemainingSamples = // this is an upper bound - swr_get_out_samples(swrContext_.get(), 0); - - if (numRemainingSamples == 0) { - return std::nullopt; - } - - int numChannels = streamInfo.audioStreamOptions.numChannels.value_or( - getNumChannels(streamInfo.codecContext)); - torch::Tensor lastSamples = - torch::empty({numChannels, numRemainingSamples}, torch::kFloat32); - - std::vector<uint8_t*> outputBuffers(numChannels); - for (auto i = 0; i < numChannels; i++) { - outputBuffers[i] = static_cast<uint8_t*>(lastSamples[i].data_ptr()); - } - - auto actualNumRemainingSamples = swr_convert( - swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0); - - return lastSamples.narrow( - /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); -} - // -------------------------------------------------------------------------- // OUTPUT ALLOCATION AND SHAPE CONVERSION // 
-------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index d39f6db00..4b41811ff 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -268,11 +268,10 @@ class SingleStreamDecoder { UniqueAVFrame& avFrame, std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); - void convertAudioAVFrameToFrameOutputOnCPU( - UniqueAVFrame& srcAVFrame, - FrameOutput& frameOutput); - - std::optional<torch::Tensor> maybeFlushSwrBuffers(); + void convertAVFrameToFrameOutputOnCPU( + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); // -------------------------------------------------------------------------- // PTS <-> INDEX CONVERSIONS @@ -347,11 +346,6 @@ class SingleStreamDecoder { int64_t lastDecodedAvFramePts_ = 0; int64_t lastDecodedAvFrameDuration_ = 0; - // Audio only. We cache it for performance. The video equivalents live in - // deviceInterface_. We store swrContext_ here because we only handle audio - // on the CPU. - UniqueSwrContext swrContext_; - // Stores various internal decoding stats. DecodeStats decodeStats_;