@@ -195,18 +195,20 @@ CudaDeviceInterface::~CudaDeviceInterface() {
195195
196196void CudaDeviceInterface::initialize (
197197 AVCodecContext* codecContext,
198- const VideoStreamOptions& videoStreamOptions,
199- [[maybe_unused]] const std::vector<std::unique_ptr<Transform>>& transforms,
200- const AVRational& timeBase,
201- [[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {
198+ const AVRational& timeBase) {
202199 TORCH_CHECK (ctx_, " FFmpeg HW device has not been initialized" );
203200 TORCH_CHECK (codecContext != nullptr , " codecContext is null" );
204-
205201 codecContext->hw_device_ctx = av_buffer_ref (ctx_.get ());
206- videoStreamOptions_ = videoStreamOptions;
207202 timeBase_ = timeBase;
208203}
209204
205+ void CudaDeviceInterface::initializeVideo (
206+ const VideoStreamOptions& videoStreamOptions,
207+ [[maybe_unused]] const std::vector<std::unique_ptr<Transform>>& transforms,
208+ [[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {
209+ videoStreamOptions_ = videoStreamOptions;
210+ }
211+
210212UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12 (
211213 UniqueAVFrame& avFrame) {
212214 // We need FFmpeg filters to handle those conversion cases which are not
@@ -220,13 +222,13 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12(
220222 return std::move (avFrame);
221223 }
222224
225+ auto hwFramesCtx =
226+ reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
223227 TORCH_CHECK (
224- avFrame-> hw_frames_ctx != nullptr ,
228+ hwFramesCtx != nullptr ,
225229 " The AVFrame does not have a hw_frames_ctx. "
226230 " That's unexpected, please report this to the TorchCodec repo." );
227231
228- auto hwFramesCtx =
229- reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
230232 AVPixelFormat actualFormat = hwFramesCtx->sw_format ;
231233
232234 // If the frame is already in NV12 format, we don't need to do anything.
@@ -355,10 +357,10 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
355357 TORCH_CHECK (
356358 cpuInterface != nullptr , " Failed to create CPU device interface" );
357359 cpuInterface->initialize (
358- nullptr ,
360+ /* codecContext=*/ nullptr , timeBase_);
361+ cpuInterface->initializeVideo (
359362 VideoStreamOptions (),
360363 {},
361- timeBase_,
362364 /* resizedOutputDims=*/ std::nullopt );
363365
364366 cpuInterface->convertAVFrameToFrameOutput (avFrame, cpuFrameOutput);
0 commit comments