Add support for CUDA >= 12.9 #757
Merged
Changes shown from 4 commits; the PR contains 17 commits in total.

Commits:
- 886f64a (Kh4L): Update NPP calls for CUDA >= 12.9
- 6023ea3 (NicolasHug): Add testing against CUDA 12.9
- d66cb33 (NicolasHug): Linter
- 69753e0 (NicolasHug): Merge branch 'cuda129' into cuda129_update
- ecf01a9 (NicolasHug): Move nppContext creation into separate function. Also rely on device_
- 565896e (NicolasHug): Use cache for nppContext object
- 9a6d3d3 (NicolasHug): Add maybe_unused
- 2d681ad (NicolasHug): Pass positive index
- 4868724 (NicolasHug): Merge branch 'main' of github.com:pytorch/torchcodec into cuda129_update
- 320c060 (NicolasHug): Try manual creation for all CUDA versions
- 7056fc0 (NicolasHug): remove cache, it should be per-device not per decoder instance. Leaving
- 9ddc670 (NicolasHug): Revert "remove cache, it should be per-device not per decoder instanc…"
- 8853ec8 (NicolasHug): Merge branch 'main' of github.com:pytorch/torchcodec into cuda129_update
- c27d4b5 (NicolasHug): Remove deviceGuard
- 5306ca4 (NicolasHug): Revert "Remove deviceGuard"
- b454c0c (NicolasHug): Reapply "remove cache, it should be per-device not per decoder instan…"
- 8cedbde (NicolasHug): remove comment
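Several of the commits above (ecf01a9, 565896e, 7056fc0, b454c0c) revolve around moving the nppContext creation into its own function, relying on `device_`, and deciding whether to cache the context, with the commit messages noting that any cache would have to be per device rather than per decoder instance. The sketch below is only an illustration of that shape under stated assumptions: the function names, the `std::map` cache, and the mutex are made up for this sketch and are not code from the PR.

```cpp
#include <map>
#include <mutex>

#include <cuda_runtime.h>
#include <npp.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/core/Device.h>
#include <c10/util/Exception.h>

namespace {

// Fill an NppStreamContext from the properties of a given device. The stream
// is intentionally left unset here; it is attached per call below.
NppStreamContext createNppStreamContext(int deviceIndex) {
  cudaDeviceProp prop{};
  cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
  TORCH_CHECK(err == cudaSuccess, "cudaGetDeviceProperties failed");

  NppStreamContext ctx{};
  ctx.nCudaDeviceId = deviceIndex;
  ctx.nMultiProcessorCount = prop.multiProcessorCount;
  ctx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
  ctx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
  ctx.nSharedMemPerBlock = prop.sharedMemPerBlock;
  ctx.nCudaDevAttrComputeCapabilityMajor = prop.major;
  ctx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
  ctx.nStreamFlags = 0;
  return ctx;
}

// Per-device cache: device properties do not change at runtime, so one
// context per device index can be reused. The current stream is attached on
// every call because it may differ between calls.
NppStreamContext getNppStreamContext(const c10::Device& device) {
  static std::mutex cacheMutex;
  static std::map<int, NppStreamContext> cache;

  // Assumes device.index() is a valid, non-negative CUDA device index
  // (cf. the "Pass positive index" commit).
  int deviceIndex = static_cast<int>(device.index());

  std::lock_guard<std::mutex> lock(cacheMutex);
  auto it = cache.find(deviceIndex);
  if (it == cache.end()) {
    it = cache.emplace(deviceIndex, createNppStreamContext(deviceIndex)).first;
  }

  NppStreamContext ctx = it->second;  // copy, then attach the caller's stream
  ctx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
  return ctx;
}

} // anonymous namespace
```

The PR itself went back and forth on caching (see the "remove cache" / revert / reapply commits), so treat the caching part purely as an illustration of the per-device idea rather than the merged behavior.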
Diff:

```diff
@@ -224,41 +224,66 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
   // Use the user-requested GPU for running the NPP kernel.
   c10::cuda::CUDAGuard deviceGuard(device_);
 
+  cudaStream_t rawStream = at::cuda::getCurrentCUDAStream().stream();
+
+  // Build an NppStreamContext, either via the old helper or by hand on
+  // CUDA 12.9+
+  NppStreamContext nppCtx{};
+#if CUDA_VERSION < 12090
+  NppStatus ctxStat = nppGetStreamContext(&nppCtx);
+  TORCH_CHECK(ctxStat == NPP_SUCCESS, "nppGetStreamContext failed");
+  // override if you want to force a particular stream
+  nppCtx.hStream = rawStream;
+#else
+  // CUDA 12.9+: helper was removed, we need to build it manually
+  int dev = 0;
+  cudaError_t err = cudaGetDevice(&dev);
+  TORCH_CHECK(err == cudaSuccess, "cudaGetDevice failed");
+  cudaDeviceProp prop{};
+  err = cudaGetDeviceProperties(&prop, dev);
+  TORCH_CHECK(err == cudaSuccess, "cudaGetDeviceProperties failed");
+
+  nppCtx.nCudaDeviceId = dev;
+  nppCtx.nMultiProcessorCount = prop.multiProcessorCount;
+  nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
+  nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
+  nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock;
+  nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
+  nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
+  nppCtx.nStreamFlags = 0;
+  nppCtx.hStream = rawStream;
+#endif
+
+  // Prepare ROI + pointers
   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
 
   auto start = std::chrono::high_resolution_clock::now();
   NppStatus status;
 
   if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
-    status = nppiNV12ToRGB_709CSC_8u_P2C3R(
+    status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
         input,
         avFrame->linesize[0],
         static_cast<Npp8u*>(dst.data_ptr()),
         dst.stride(0),
-        oSizeROI);
+        oSizeROI,
+        nppCtx);
   } else {
-    status = nppiNV12ToRGB_8u_P2C3R(
+    status = nppiNV12ToRGB_8u_P2C3R_Ctx(
         input,
         avFrame->linesize[0],
         static_cast<Npp8u*>(dst.data_ptr()),
         dst.stride(0),
-        oSizeROI);
+        oSizeROI,
+        nppCtx);
   }
   TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
 
-  // Make the pytorch stream wait for the npp kernel to finish before using the
-  // output.
-  at::cuda::CUDAEvent nppDoneEvent;
-  at::cuda::CUDAStream nppStreamWrapper =
-      c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
-  nppDoneEvent.record(nppStreamWrapper);
-  nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
```
Contributor comment on the removed synchronization block (marked resolved): these syncs aren't needed anymore because the code now explicitly asks NPP to run on the stream passed via nppCtx.hStream.
```diff
 
   auto end = std::chrono::high_resolution_clock::now();
 
-  std::chrono::duration<double, std::micro> duration = end - start;
-  VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width
-          << " took: " << duration.count() << "us" << std::endl;
+  auto duration = std::chrono::duration<double, std::micro>(end - start);
+  VLOG(9) << "NPP Conversion of frame h=" << height << " w=" << width
+      << " took: " << duration.count() << "us";
 }
 
 // inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
```
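For context on the resolved comment above: the deleted block implemented the usual cross-stream ordering pattern: record an event on the stream the kernel was enqueued on, then make the current PyTorch stream wait on that event. A generic restatement of that pattern is sketched below; the helper name is made up for illustration. It becomes unnecessary here because the _Ctx NPP calls are handed the current stream directly through nppCtx.hStream, so there is no second stream to wait on.

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>

// Make the current PyTorch stream wait for all work already enqueued on
// `otherRawStream` (e.g. NPP's own stream when the non-_Ctx API is used).
void waitForExternalStream(cudaStream_t otherRawStream,
                           c10::DeviceIndex deviceIndex) {
  at::cuda::CUDAEvent doneEvent;
  at::cuda::CUDAStream otherStream =
      c10::cuda::getStreamFromExternal(otherRawStream, deviceIndex);
  doneEvent.record(otherStream);                      // marks the end of the other stream's queued work
  doneEvent.block(at::cuda::getCurrentCUDAStream());  // current stream waits for that mark
}
```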
Review comment (on the deviceGuard line): This guard isn't needed anymore as we now explicitly pass the current device to the NppContext creation.
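As a side note on what that guard does (a generic sketch, not code from this PR): c10::cuda::CUDAGuard makes the given device the current CUDA device for its scope and restores the previous one on destruction. The comment's point is that once the device index is passed explicitly into the context creation, those queries no longer rely on whichever device happens to be current.

```cpp
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGuard.h>

void withDeviceCurrent(const c10::Device& device) {
  // Sets the current CUDA device to `device` for the rest of this scope.
  c10::cuda::CUDAGuard deviceGuard(device);

  // ... calls that implicitly use the *current* device go here; calls that
  // take a device index explicitly do not depend on the guard ...
}  // previous current device is restored here
```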