@@ -174,6 +174,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
174174 TORCH_CHECK (g_cuda, " CudaDeviceInterface was not registered!" );
175175 TORCH_CHECK (
176176 device_.type () == torch::kCUDA , " Unsupported device: " , device_.str ());
177+
178+ // It is important for pytorch itself to create the cuda context. If ffmpeg
179+ // creates the context it may not be compatible with pytorch.
180+ // This is a dummy tensor to initialize the cuda context.
181+ torch::Tensor dummyTensorForCudaInitialization = torch::empty (
182+ {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
183+ ctx_ = getCudaContext (device_);
184+ nppCtx_ = getNppStreamContext (device_);
177185}
178186
179187CudaDeviceInterface::~CudaDeviceInterface () {
@@ -191,20 +199,12 @@ void CudaDeviceInterface::initialize(
191199 [[maybe_unused]] const std::vector<std::unique_ptr<Transform>>& transforms,
192200 const AVRational& timeBase,
193201 [[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {
194- TORCH_CHECK (! ctx_, " FFmpeg HW device context already initialized" );
202+ TORCH_CHECK (ctx_, " FFmpeg HW device has not been initialized" );
195203 TORCH_CHECK (codecContext != nullptr , " codecContext is null" );
196204
205+ codecContext->hw_device_ctx = av_buffer_ref (ctx_.get ());
197206 videoStreamOptions_ = videoStreamOptions;
198207 timeBase_ = timeBase;
199-
200- // It is important for pytorch itself to create the cuda context. If ffmpeg
201- // creates the context it may not be compatible with pytorch.
202- // This is a dummy tensor to initialize the cuda context.
203- torch::Tensor dummyTensorForCudaInitialization = torch::empty (
204- {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
205- ctx_ = getCudaContext (device_);
206- nppCtx_ = getNppStreamContext (device_);
207- codecContext->hw_device_ctx = av_buffer_ref (ctx_.get ());
208208}
209209
210210UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12 (
@@ -304,7 +304,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
304304 std::optional<torch::Tensor> preAllocatedOutputTensor) {
305305 // Note that CUDA does not yet support transforms, so the only possible
306306 // frame dimensions are the raw decoded frame's dimensions.
307- auto frameDims = FrameDims (avFrame->width , avFrame->height );
307+ auto frameDims = FrameDims (avFrame->height , avFrame->width );
308308
309309 if (preAllocatedOutputTensor.has_value ()) {
310310 auto shape = preAllocatedOutputTensor.value ().sizes ();
@@ -379,14 +379,15 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
379379
380380 // Above we checked that the AVFrame was on GPU, but that's not enough, we
381381 // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
382- // because this is what the NPP color conversion routines expect.
382+ // because this is what the NPP color conversion routines expect. This SHOULD
383+ // be enforced by our call to maybeConvertAVFrameToNV12() above.
384+ auto hwFramesCtx =
385+ reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
383386 TORCH_CHECK (
384- avFrame-> hw_frames_ctx != nullptr ,
387+ hwFramesCtx != nullptr ,
385388 " The AVFrame does not have a hw_frames_ctx. "
386389 " That's unexpected, please report this to the TorchCodec repo." );
387390
388- auto hwFramesCtx =
389- reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
390391 AVPixelFormat actualFormat = hwFramesCtx->sw_format ;
391392
392393 TORCH_CHECK (
0 commit comments