Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/linux_cuda_wheel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ jobs:
# For the actual release we should add that label and change this to
# include more python versions.
python-version: ['3.9']
cuda-version: ['12.6', '12.8']
# We test against 12.6 and 12.9 to keep the CI matrix small,
# but for releases we should also add 12.8.
cuda-version: ['12.6', '12.9']
# TODO: put back ffmpeg 5 https://github.com/pytorch/torchcodec/issues/325
ffmpeg-version-for-tests: ['4.4.2', '6', '7']

Expand Down
57 changes: 41 additions & 16 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,41 +224,66 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
// Use the user-requested GPU for running the NPP kernel.
c10::cuda::CUDAGuard deviceGuard(device_);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This guard isn't needed anymore as we now explicitly pass the current device to the NppStreamContext creation.


cudaStream_t rawStream = at::cuda::getCurrentCUDAStream().stream();

// Build an NppStreamContext, either via the old helper or by hand on
// CUDA 12.9+
NppStreamContext nppCtx{};
#if CUDA_VERSION < 12090
NppStatus ctxStat = nppGetStreamContext(&nppCtx);
TORCH_CHECK(ctxStat == NPP_SUCCESS, "nppGetStreamContext failed");
// override if you want to force a particular stream
nppCtx.hStream = rawStream;
#else
// CUDA 12.9+: helper was removed, we need to build it manually
int dev = 0;
cudaError_t err = cudaGetDevice(&dev);
TORCH_CHECK(err == cudaSuccess, "cudaGetDevice failed");
cudaDeviceProp prop{};
err = cudaGetDeviceProperties(&prop, dev);
TORCH_CHECK(err == cudaSuccess, "cudaGetDeviceProperties failed");

nppCtx.nCudaDeviceId = dev;
nppCtx.nMultiProcessorCount = prop.multiProcessorCount;
nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock;
nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
nppCtx.nStreamFlags = 0;
nppCtx.hStream = rawStream;
#endif

// Prepare ROI + pointers
NppiSize oSizeROI = {width, height};
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};

auto start = std::chrono::high_resolution_clock::now();
NppStatus status;

if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
status = nppiNV12ToRGB_709CSC_8u_P2C3R(
status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
input,
avFrame->linesize[0],
static_cast<Npp8u*>(dst.data_ptr()),
dst.stride(0),
oSizeROI);
oSizeROI,
nppCtx);
} else {
status = nppiNV12ToRGB_8u_P2C3R(
status = nppiNV12ToRGB_8u_P2C3R_Ctx(
input,
avFrame->linesize[0],
static_cast<Npp8u*>(dst.data_ptr()),
dst.stride(0),
oSizeROI);
oSizeROI,
nppCtx);
}
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");

// Make the pytorch stream wait for the npp kernel to finish before using the
// output.
at::cuda::CUDAEvent nppDoneEvent;
at::cuda::CUDAStream nppStreamWrapper =
c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
nppDoneEvent.record(nppStreamWrapper);
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These syncs aren't needed anymore because we now explicitly ask NPP to run on PyTorch's current stream.


auto end = std::chrono::high_resolution_clock::now();

std::chrono::duration<double, std::micro> duration = end - start;
VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width
<< " took: " << duration.count() << "us" << std::endl;
auto duration = std::chrono::duration<double, std::micro>(end - start);
VLOG(9) << "NPP Conversion of frame h=" << height << " w=" << width
<< " took: " << duration.count() << "us";
}

// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
Expand Down
Loading