Skip to content

Commit 886f64a

Browse files
committed
Update NPP calls for CUDA >= 12.9
1 parent 103f714 commit 886f64a

File tree

1 file changed

+43
-17
lines changed

1 file changed

+43
-17
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -224,41 +224,67 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
224224
// Use the user-requested GPU for running the NPP kernel.
225225
c10::cuda::CUDAGuard deviceGuard(device_);
226226

227-
NppiSize oSizeROI = {width, height};
228-
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
227+
cudaStream_t rawStream = at::cuda::getCurrentCUDAStream().stream();
228+
229+
// Build an NppStreamContext, either via the old helper or by hand on CUDA 12.9+
230+
NppStreamContext nppCtx{};
231+
#if CUDA_VERSION < 12090
232+
NppStatus ctxStat = nppGetStreamContext(&nppCtx);
233+
TORCH_CHECK(ctxStat == NPP_SUCCESS, "nppGetStreamContext failed");
234+
// Always run the NPP call on PyTorch's current CUDA stream (rawStream).
235+
nppCtx.hStream = rawStream;
236+
#else
237+
// CUDA 12.9+: the helper was removed, so the context must be built manually.
238+
int dev = 0;
239+
cudaError_t err = cudaGetDevice(&dev);
240+
TORCH_CHECK(err == cudaSuccess, "cudaGetDevice failed");
241+
cudaDeviceProp prop{};
242+
err = cudaGetDeviceProperties(&prop, dev);
243+
TORCH_CHECK(err == cudaSuccess, "cudaGetDeviceProperties failed");
244+
245+
nppCtx.nCudaDeviceId = dev;
246+
nppCtx.nMultiProcessorCount = prop.multiProcessorCount;
247+
nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
248+
nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
249+
nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock;
250+
nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
251+
nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
252+
nppCtx.nStreamFlags = 0;
253+
nppCtx.hStream = rawStream;
254+
#endif
255+
256+
// Prepare ROI + pointers
257+
NppiSize oSizeROI = { width, height };
258+
Npp8u* input[2] = { avFrame->data[0], avFrame->data[1] };
229259

230260
auto start = std::chrono::high_resolution_clock::now();
231261
NppStatus status;
262+
232263
if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
233-
status = nppiNV12ToRGB_709CSC_8u_P2C3R(
264+
status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
234265
input,
235266
avFrame->linesize[0],
236267
static_cast<Npp8u*>(dst.data_ptr()),
237268
dst.stride(0),
238-
oSizeROI);
269+
oSizeROI,
270+
nppCtx);
239271
} else {
240-
status = nppiNV12ToRGB_8u_P2C3R(
272+
status = nppiNV12ToRGB_8u_P2C3R_Ctx(
241273
input,
242274
avFrame->linesize[0],
243275
static_cast<Npp8u*>(dst.data_ptr()),
244276
dst.stride(0),
245-
oSizeROI);
277+
oSizeROI,
278+
nppCtx);
246279
}
247280
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
248281

249-
// Make the pytorch stream wait for the npp kernel to finish before using the
250-
// output.
251-
at::cuda::CUDAEvent nppDoneEvent;
252-
at::cuda::CUDAStream nppStreamWrapper =
253-
c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
254-
nppDoneEvent.record(nppStreamWrapper);
255-
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
256-
257282
auto end = std::chrono::high_resolution_clock::now();
283+
auto duration = std::chrono::duration<double, std::micro>(end - start);
284+
VLOG(9) << "NPP Conversion of frame h=" << height
285+
<< " w=" << width
286+
<< " took: " << duration.count() << "us";
258287

259-
std::chrono::duration<double, std::micro> duration = end - start;
260-
VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width
261-
<< " took: " << duration.count() << "us" << std::endl;
262288
}
263289

264290
// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9

0 commit comments

Comments
 (0)