From 159672d02fbe481576c355442cb95d13627ed902 Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Wed, 29 Oct 2025 14:58:22 -0700 Subject: [PATCH 1/2] refactor device interface --- .../_core/BetaCudaDeviceInterface.cpp | 4 +- .../_core/BetaCudaDeviceInterface.h | 1 + src/torchcodec/_core/CpuDeviceInterface.cpp | 141 +++++++++++++++++- src/torchcodec/_core/CpuDeviceInterface.h | 19 +++ src/torchcodec/_core/CudaDeviceInterface.cpp | 4 +- src/torchcodec/_core/CudaDeviceInterface.h | 1 + src/torchcodec/_core/DeviceInterface.h | 16 ++ src/torchcodec/_core/SingleStreamDecoder.cpp | 134 +---------------- src/torchcodec/_core/SingleStreamDecoder.h | 11 -- 9 files changed, 187 insertions(+), 144 deletions(-) diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index b0caa9705..70ad8d5eb 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -671,11 +671,13 @@ void BetaCudaDeviceInterface::flush() { void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, + [[maybe_unused]] AVMediaType mediaType, std::optional preAllocatedOutputTensor) { if (cpuFallback_) { // CPU decoded frame - need to do CPU color conversion then transfer to GPU FrameOutput cpuFrameOutput; - cpuFallback_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); + cpuFallback_->convertAVFrameToFrameOutput( + avFrame, cpuFrameOutput, AVMEDIA_TYPE_VIDEO); // Transfer CPU frame to GPU if (preAllocatedOutputTensor.has_value()) { diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index 29511df50..a9cf0342a 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -46,6 +46,7 @@ class BetaCudaDeviceInterface : public DeviceInterface { void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, + AVMediaType mediaType, std::optional preAllocatedOutputTensor = std::nullopt) override; diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 5aa20b09e..724940949 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -110,6 +110,12 @@ void CpuDeviceInterface::initializeVideo( initialized_ = true; } +void CpuDeviceInterface::initializeAudio( + const AudioStreamOptions& audioStreamOptions) { + audioStreamOptions_ = audioStreamOptions; + initialized_ = true; +} + ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary( const FrameDims& outputDims) const { // swscale requires widths to be multiples of 32: @@ -138,6 +144,21 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary( } } +void CpuDeviceInterface::convertAVFrameToFrameOutput( + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + AVMediaType mediaType, + std::optional preAllocatedOutputTensor) { + TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized."); + + if (mediaType == AVMEDIA_TYPE_AUDIO) { + convertAudioAVFrameToFrameOutput(avFrame, frameOutput); + } else { + convertVideoAVFrameToFrameOutput( + avFrame, frameOutput, preAllocatedOutputTensor); + } +} + // Note [preAllocatedOutputTensor with swscale and filtergraph]: // Callers may pass a pre-allocated tensor, where the output.data tensor will // be stored. This parameter is honored in any case, but it only leads to a @@ -147,12 +168,10 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary( // TODO: Figure out whether that's possible! // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of // `dimension_order` parameter. It's up to callers to re-shape it if needed. -void CpuDeviceInterface::convertAVFrameToFrameOutput( +void CpuDeviceInterface::convertVideoAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { - TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized."); - // Note that we ignore the dimensions from the metadata; we don't even bother // storing them. The resized dimensions take priority. If we don't have any, // then we use the dimensions from the actual decoded frame. We use the actual @@ -346,6 +365,122 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( return rgbAVFrameToTensor(filterGraph_->convert(avFrame)); } +void CpuDeviceInterface::convertAudioAVFrameToFrameOutput( + UniqueAVFrame& srcAVFrame, + FrameOutput& frameOutput) { + AVSampleFormat srcSampleFormat = + static_cast(srcAVFrame->format); + AVSampleFormat outSampleFormat = AV_SAMPLE_FMT_FLTP; + + int srcSampleRate = srcAVFrame->sample_rate; + int outSampleRate = audioStreamOptions_.sampleRate.value_or(srcSampleRate); + + int srcNumChannels = getNumChannels(codecContext_); + TORCH_CHECK( + srcNumChannels == getNumChannels(srcAVFrame), + "The frame has ", + getNumChannels(srcAVFrame), + " channels, expected ", + srcNumChannels, + ". If you are hitting this, it may be because you are using " + "a buggy FFmpeg version. FFmpeg4 is known to fail here in some " + "valid scenarios. Try to upgrade FFmpeg?"); + int outNumChannels = audioStreamOptions_.numChannels.value_or(srcNumChannels); + + bool mustConvert = + (srcSampleFormat != outSampleFormat || srcSampleRate != outSampleRate || + srcNumChannels != outNumChannels); + + UniqueAVFrame convertedAVFrame; + if (mustConvert) { + if (!swrContext_) { + swrContext_.reset(createSwrContext( + srcSampleFormat, + outSampleFormat, + srcSampleRate, + outSampleRate, + srcAVFrame, + outNumChannels)); + } + + convertedAVFrame = convertAudioAVFrameSamples( + swrContext_, + srcAVFrame, + outSampleFormat, + outSampleRate, + outNumChannels); + } + const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; + + AVSampleFormat format = static_cast(avFrame->format); + TORCH_CHECK( + format == outSampleFormat, + "Something went wrong, the frame didn't get converted to the desired format. ", + "Desired format = ", + av_get_sample_fmt_name(outSampleFormat), + "source format = ", + av_get_sample_fmt_name(format)); + + int numChannels = getNumChannels(avFrame); + TORCH_CHECK( + numChannels == outNumChannels, + "Something went wrong, the frame didn't get converted to the desired ", + "number of channels = ", + outNumChannels, + ". Got ", + numChannels, + " instead."); + + auto numSamples = avFrame->nb_samples; + + frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32); + + if (numSamples > 0) { + uint8_t* outputChannelData = + static_cast(frameOutput.data.data_ptr()); + auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); + for (auto channel = 0; channel < numChannels; + ++channel, outputChannelData += numBytesPerChannel) { + std::memcpy( + outputChannelData, + avFrame->extended_data[channel], + numBytesPerChannel); + } + } +} + +std::optional CpuDeviceInterface::maybeFlushAudioBuffers() { + // When sample rate conversion is involved, swresample buffers some of the + // samples in-between calls to swr_convert (see the libswresample docs). + // That's because the last few samples in a given frame require future + // samples from the next frame to be properly converted. This function + // flushes out the samples that are stored in swresample's buffers. + if (!swrContext_) { + return std::nullopt; + } + auto numRemainingSamples = swr_get_out_samples(swrContext_.get(), 0); + + if (numRemainingSamples == 0) { + return std::nullopt; + } + + int numChannels = + audioStreamOptions_.numChannels.value_or(getNumChannels(codecContext_)); + torch::Tensor lastSamples = + torch::empty({numChannels, numRemainingSamples}, torch::kFloat32); + + std::vector outputBuffers(numChannels); + for (auto i = 0; i < numChannels; i++) { + outputBuffers[i] = static_cast(lastSamples[i].data_ptr()); + } + + auto actualNumRemainingSamples = swr_convert( + swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0); + + return lastSamples.narrow( + /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); +} + std::string CpuDeviceInterface::getDetails() { return std::string("CPU Device Interface."); } diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 3f6f7c962..6267acd4b 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -33,15 +33,30 @@ class CpuDeviceInterface : public DeviceInterface { const std::vector>& transforms, const std::optional& resizedOutputDims) override; + virtual void initializeAudio( + const AudioStreamOptions& audioStreamOptions) override; + + virtual std::optional maybeFlushAudioBuffers() override; + void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, + AVMediaType mediaType, std::optional preAllocatedOutputTensor = std::nullopt) override; std::string getDetails() override; private: + void convertAudioAVFrameToFrameOutput( + UniqueAVFrame& srcAVFrame, + FrameOutput& frameOutput); + + void convertVideoAVFrameToFrameOutput( + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor); + int convertAVFrameToTensorUsingSwScale( const UniqueAVFrame& avFrame, torch::Tensor& outputTensor, @@ -130,6 +145,10 @@ class CpuDeviceInterface : public DeviceInterface { bool userRequestedSwScale_; bool initialized_ = false; + + // Audio-specific members + AudioStreamOptions audioStreamOptions_; + UniqueSwrContext swrContext_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index be45050e6..99d0ddad9 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -238,6 +238,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( void CudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, + [[maybe_unused]] AVMediaType mediaType, std::optional preAllocatedOutputTensor) { validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame); @@ -271,7 +272,8 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( } else { // Reason 2 above. We need to do a full conversion which requires an // actual CPU device. - cpuInterface_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); + cpuInterface_->convertAVFrameToFrameOutput( + avFrame, cpuFrameOutput, AVMEDIA_TYPE_VIDEO); } // Finally, we need to send the frame back to the GPU. Note that the diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 9f171ee3c..ab5c58ce7 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -37,6 +37,7 @@ class CudaDeviceInterface : public DeviceInterface { void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, + AVMediaType mediaType, std::optional preAllocatedOutputTensor = std::nullopt) override; diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 773317e83..e13756122 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -65,6 +65,21 @@ class DeviceInterface { transforms, [[maybe_unused]] const std::optional& resizedOutputDims) {} + // Initialize the device with parameters specific to audio decoding. There is + // a default empty implementation. + virtual void initializeAudio( + [[maybe_unused]] const AudioStreamOptions& audioStreamOptions) {} + + // Flush any remaining samples from the audio resampler buffer. + // When sample rate conversion is involved, some samples may be buffered + // between frames for proper interpolation. This function flushes those + // buffered samples. + // Returns an optional tensor containing the flushed samples, or std::nullopt + // if there are no buffered samples or audio is not supported. + virtual std::optional maybeFlushAudioBuffers() { + return std::nullopt; + } + // In order for decoding to actually happen on an FFmpeg managed hardware // device, we need to register the DeviceInterface managed // AVHardwareDeviceContext with the AVCodecContext. We don't need to do this @@ -75,6 +90,7 @@ class DeviceInterface { virtual void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, + AVMediaType mediaType, std::optional preAllocatedOutputTensor = std::nullopt) = 0; // ------------------------------------------ diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 8d9e9f651..29568e326 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -570,6 +570,9 @@ void SingleStreamDecoder::addAudioStream( // support that format, but it looks like it does nothing, so this probably // doesn't hurt. streamInfo.codecContext->request_sample_fmt = AV_SAMPLE_FMT_FLTP; + + // Initialize device interface for audio + deviceInterface_->initializeAudio(audioStreamOptions); } // -------------------------------------------------------------------------- @@ -1025,7 +1028,7 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( (stopPts <= lastDecodedAvFrameEnd); } - auto lastSamples = maybeFlushSwrBuffers(); + auto lastSamples = deviceInterface_->maybeFlushAudioBuffers(); if (lastSamples.has_value()) { frames.push_back(*lastSamples); } @@ -1293,136 +1296,11 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( frameOutput.durationSeconds = ptsToSeconds( getDuration(avFrame), formatContext_->streams[activeStreamIndex_]->time_base); - if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { - convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); - } else { - deviceInterface_->convertAVFrameToFrameOutput( - avFrame, frameOutput, preAllocatedOutputTensor); - } + deviceInterface_->convertAVFrameToFrameOutput( + avFrame, frameOutput, streamInfo.avMediaType, preAllocatedOutputTensor); return frameOutput; } -void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( - UniqueAVFrame& srcAVFrame, - FrameOutput& frameOutput) { - AVSampleFormat srcSampleFormat = - static_cast(srcAVFrame->format); - AVSampleFormat outSampleFormat = AV_SAMPLE_FMT_FLTP; - - StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - int srcSampleRate = srcAVFrame->sample_rate; - int outSampleRate = - streamInfo.audioStreamOptions.sampleRate.value_or(srcSampleRate); - - int srcNumChannels = getNumChannels(streamInfo.codecContext); - TORCH_CHECK( - srcNumChannels == getNumChannels(srcAVFrame), - "The frame has ", - getNumChannels(srcAVFrame), - " channels, expected ", - srcNumChannels, - ". If you are hitting this, it may be because you are using " - "a buggy FFmpeg version. FFmpeg4 is known to fail here in some " - "valid scenarios. Try to upgrade FFmpeg?"); - int outNumChannels = - streamInfo.audioStreamOptions.numChannels.value_or(srcNumChannels); - - bool mustConvert = - (srcSampleFormat != outSampleFormat || srcSampleRate != outSampleRate || - srcNumChannels != outNumChannels); - - UniqueAVFrame convertedAVFrame; - if (mustConvert) { - if (!swrContext_) { - swrContext_.reset(createSwrContext( - srcSampleFormat, - outSampleFormat, - srcSampleRate, - outSampleRate, - srcAVFrame, - outNumChannels)); - } - - convertedAVFrame = convertAudioAVFrameSamples( - swrContext_, - srcAVFrame, - outSampleFormat, - outSampleRate, - outNumChannels); - } - const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; - - AVSampleFormat format = static_cast(avFrame->format); - TORCH_CHECK( - format == outSampleFormat, - "Something went wrong, the frame didn't get converted to the desired format. ", - "Desired format = ", - av_get_sample_fmt_name(outSampleFormat), - "source format = ", - av_get_sample_fmt_name(format)); - - int numChannels = getNumChannels(avFrame); - TORCH_CHECK( - numChannels == outNumChannels, - "Something went wrong, the frame didn't get converted to the desired ", - "number of channels = ", - outNumChannels, - ". Got ", - numChannels, - " instead."); - - auto numSamples = avFrame->nb_samples; // per channel - - frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32); - - if (numSamples > 0) { - uint8_t* outputChannelData = - static_cast(frameOutput.data.data_ptr()); - auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); - for (auto channel = 0; channel < numChannels; - ++channel, outputChannelData += numBytesPerChannel) { - std::memcpy( - outputChannelData, - avFrame->extended_data[channel], - numBytesPerChannel); - } - } -} - -std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { - // When sample rate conversion is involved, swresample buffers some of the - // samples in-between calls to swr_convert (see the libswresample docs). - // That's because the last few samples in a given frame require future - // samples from the next frame to be properly converted. This function - // flushes out the samples that are stored in swresample's buffers. - auto& streamInfo = streamInfos_[activeStreamIndex_]; - if (!swrContext_) { - return std::nullopt; - } - auto numRemainingSamples = // this is an upper bound - swr_get_out_samples(swrContext_.get(), 0); - - if (numRemainingSamples == 0) { - return std::nullopt; - } - - int numChannels = streamInfo.audioStreamOptions.numChannels.value_or( - getNumChannels(streamInfo.codecContext)); - torch::Tensor lastSamples = - torch::empty({numChannels, numRemainingSamples}, torch::kFloat32); - - std::vector outputBuffers(numChannels); - for (auto i = 0; i < numChannels; i++) { - outputBuffers[i] = static_cast(lastSamples[i].data_ptr()); - } - - auto actualNumRemainingSamples = swr_convert( - swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0); - - return lastSamples.narrow( - /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); -} - // -------------------------------------------------------------------------- // OUTPUT ALLOCATION AND SHAPE CONVERSION // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 4d4c11aa2..49186d8d2 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -273,10 +273,6 @@ class SingleStreamDecoder { FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt); - void convertAudioAVFrameToFrameOutputOnCPU( - UniqueAVFrame& srcAVFrame, - FrameOutput& frameOutput); - torch::Tensor convertAVFrameToTensorUsingFilterGraph( const UniqueAVFrame& avFrame); @@ -284,8 +280,6 @@ class SingleStreamDecoder { const UniqueAVFrame& avFrame, torch::Tensor& outputTensor); - std::optional maybeFlushSwrBuffers(); - // -------------------------------------------------------------------------- // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- @@ -359,11 +353,6 @@ class SingleStreamDecoder { int64_t lastDecodedAvFramePts_ = 0; int64_t lastDecodedAvFrameDuration_ = 0; - // Audio only. We cache it for performance. The video equivalents live in - // deviceInterface_. We store swrContext_ here because we only handle audio - // on the CPU. - UniqueSwrContext swrContext_; - // Stores various internal decoding stats. DecodeStats decodeStats_; From 685583ce6852bd9da333185708f7547e717fe3b4 Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Thu, 30 Oct 2025 13:44:07 -0700 Subject: [PATCH 2/2] address comments --- src/torchcodec/_core/BetaCudaDeviceInterface.cpp | 1 - src/torchcodec/_core/BetaCudaDeviceInterface.h | 4 +--- src/torchcodec/_core/CpuDeviceInterface.cpp | 8 +++++--- src/torchcodec/_core/CpuDeviceInterface.h | 4 +--- src/torchcodec/_core/CudaDeviceInterface.cpp | 4 +--- src/torchcodec/_core/CudaDeviceInterface.h | 4 +--- src/torchcodec/_core/DeviceInterface.h | 2 +- src/torchcodec/_core/SingleStreamDecoder.cpp | 3 +-- 8 files changed, 11 insertions(+), 19 deletions(-) diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 753d239c1..587456f34 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -814,7 +814,6 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12( void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - [[maybe_unused]] AVMediaType mediaType, std::optional preAllocatedOutputTensor) { UniqueAVFrame gpuFrame = cpuFallback_ ? transferCpuFrameToGpuNV12(avFrame) : std::move(avFrame); diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index d53ca387a..747add12e 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -46,9 +46,7 @@ class BetaCudaDeviceInterface : public DeviceInterface { void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - AVMediaType mediaType, - std::optional preAllocatedOutputTensor = - std::nullopt) override; + std::optional preAllocatedOutputTensor) override; int sendPacket(ReferenceAVPacket& packet) override; int sendEOFPacket() override; diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index cef696efc..074979f65 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -35,6 +35,7 @@ void CpuDeviceInterface::initializeVideo( const VideoStreamOptions& videoStreamOptions, const std::vector>& transforms, const std::optional& resizedOutputDims) { + avMediaType_ = AVMEDIA_TYPE_VIDEO; videoStreamOptions_ = videoStreamOptions; resizedOutputDims_ = resizedOutputDims; @@ -88,6 +89,7 @@ void CpuDeviceInterface::initializeVideo( void CpuDeviceInterface::initializeAudio( const AudioStreamOptions& audioStreamOptions) { + avMediaType_ = AVMEDIA_TYPE_AUDIO; audioStreamOptions_ = audioStreamOptions; initialized_ = true; } @@ -123,11 +125,10 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary( void CpuDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - AVMediaType mediaType, std::optional preAllocatedOutputTensor) { TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized."); - if (mediaType == AVMEDIA_TYPE_AUDIO) { + if (avMediaType_ == AVMEDIA_TYPE_AUDIO) { convertAudioAVFrameToFrameOutput(avFrame, frameOutput); } else { convertVideoAVFrameToFrameOutput( @@ -390,7 +391,8 @@ std::optional CpuDeviceInterface::maybeFlushAudioBuffers() { if (!swrContext_) { return std::nullopt; } - auto numRemainingSamples = swr_get_out_samples(swrContext_.get(), 0); + auto numRemainingSamples = // this is an upper bound + swr_get_out_samples(swrContext_.get(), 0); if (numRemainingSamples == 0) { return std::nullopt; diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 812c80c7c..1b92cdf4b 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -41,9 +41,7 @@ class CpuDeviceInterface : public DeviceInterface { void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - AVMediaType mediaType, - std::optional preAllocatedOutputTensor = - std::nullopt) override; + std::optional preAllocatedOutputTensor) override; std::string getDetails() override; diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 99d0ddad9..be45050e6 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -238,7 +238,6 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( void CudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - [[maybe_unused]] AVMediaType mediaType, std::optional preAllocatedOutputTensor) { validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame); @@ -272,8 +271,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( } else { // Reason 2 above. We need to do a full conversion which requires an // actual CPU device. - cpuInterface_->convertAVFrameToFrameOutput( - avFrame, cpuFrameOutput, AVMEDIA_TYPE_VIDEO); + cpuInterface_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); } // Finally, we need to send the frame back to the GPU. Note that the diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index ab5c58ce7..4ab658aad 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -37,9 +37,7 @@ class CudaDeviceInterface : public DeviceInterface { void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - AVMediaType mediaType, - std::optional preAllocatedOutputTensor = - std::nullopt) override; + std::optional preAllocatedOutputTensor) override; std::string getDetails() override; diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index e13756122..34664c2d4 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -90,7 +90,6 @@ class DeviceInterface { virtual void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, - AVMediaType mediaType, std::optional preAllocatedOutputTensor = std::nullopt) = 0; // ------------------------------------------ @@ -142,6 +141,7 @@ class DeviceInterface { protected: torch::Device device_; SharedAVCodecContext codecContext_; + AVMediaType avMediaType_; }; using CreateDeviceInterfaceFn = diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 29568e326..bc2a701f0 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1289,7 +1289,6 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( std::optional preAllocatedOutputTensor) { // Convert the frame to tensor. FrameOutput frameOutput; - auto& streamInfo = streamInfos_[activeStreamIndex_]; frameOutput.ptsSeconds = ptsToSeconds( getPtsOrDts(avFrame), formatContext_->streams[activeStreamIndex_]->time_base); @@ -1297,7 +1296,7 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( getDuration(avFrame), formatContext_->streams[activeStreamIndex_]->time_base); deviceInterface_->convertAVFrameToFrameOutput( - avFrame, frameOutput, streamInfo.avMediaType, preAllocatedOutputTensor); + avFrame, frameOutput, preAllocatedOutputTensor); return frameOutput; }