@@ -224,41 +224,67 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
224224 // Use the user-requested GPU for running the NPP kernel.
225225 c10::cuda::CUDAGuard deviceGuard (device_);
226226
227- NppiSize oSizeROI = {width, height};
228- Npp8u* input[2 ] = {avFrame->data [0 ], avFrame->data [1 ]};
227+ cudaStream_t rawStream = at::cuda::getCurrentCUDAStream ().stream ();
228+
229+ // Build an NppStreamContext, either via the old helper or by hand on CUDA 12.9+
230+ NppStreamContext nppCtx{};
231+ #if CUDA_VERSION < 12090
232+ NppStatus ctxStat = nppGetStreamContext (&nppCtx);
233+ TORCH_CHECK (ctxStat == NPP_SUCCESS, " nppGetStreamContext failed" );
234+ // override if you want to force a particular stream
235+ nppCtx.hStream = rawStream;
236+ #else
237+ // CUDA 12.9+: helper was removed, we need to build it manually
238+ int dev = 0 ;
239+ cudaError_t err = cudaGetDevice (&dev);
240+ TORCH_CHECK (err == cudaSuccess, " cudaGetDevice failed" );
241+ cudaDeviceProp prop{};
242+ err = cudaGetDeviceProperties (&prop, dev);
243+ TORCH_CHECK (err == cudaSuccess, " cudaGetDeviceProperties failed" );
244+
245+ nppCtx.nCudaDeviceId = dev;
246+ nppCtx.nMultiProcessorCount = prop.multiProcessorCount ;
247+ nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor ;
248+ nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock ;
249+ nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock ;
250+ nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major ;
251+ nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor ;
252+ nppCtx.nStreamFlags = 0 ;
253+ nppCtx.hStream = rawStream;
254+ #endif
255+
256+ // Prepare ROI + pointers
257+ NppiSize oSizeROI = { width, height };
258+ Npp8u* input[2 ] = { avFrame->data [0 ], avFrame->data [1 ] };
229259
230260 auto start = std::chrono::high_resolution_clock::now ();
231261 NppStatus status;
262+
232263 if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
233- status = nppiNV12ToRGB_709CSC_8u_P2C3R (
264+ status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx (
234265 input,
235266 avFrame->linesize [0 ],
236267 static_cast <Npp8u*>(dst.data_ptr ()),
237268 dst.stride (0 ),
238- oSizeROI);
269+ oSizeROI,
270+ nppCtx);
239271 } else {
240- status = nppiNV12ToRGB_8u_P2C3R (
272+ status = nppiNV12ToRGB_8u_P2C3R_Ctx (
241273 input,
242274 avFrame->linesize [0 ],
243275 static_cast <Npp8u*>(dst.data_ptr ()),
244276 dst.stride (0 ),
245- oSizeROI);
277+ oSizeROI,
278+ nppCtx);
246279 }
247280 TORCH_CHECK (status == NPP_SUCCESS, " Failed to convert NV12 frame." );
248281
249- // Make the pytorch stream wait for the npp kernel to finish before using the
250- // output.
251- at::cuda::CUDAEvent nppDoneEvent;
252- at::cuda::CUDAStream nppStreamWrapper =
253- c10::cuda::getStreamFromExternal (nppGetStream (), device_.index ());
254- nppDoneEvent.record (nppStreamWrapper);
255- nppDoneEvent.block (at::cuda::getCurrentCUDAStream ());
256-
257282 auto end = std::chrono::high_resolution_clock::now ();
283+ auto duration = std::chrono::duration<double , std::micro>(end - start);
284+ VLOG (9 ) << " NPP Conversion of frame h=" << height
285+ << " w=" << width
286+ << " took: " << duration.count () << " us" ;
258287
259- std::chrono::duration<double , std::micro> duration = end - start;
260- VLOG (9 ) << " NPP Conversion of frame height=" << height << " width=" << width
261- << " took: " << duration.count () << " us" << std::endl;
262288}
263289
264290// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
0 commit comments