11#include < ATen/cuda/CUDAGreenContext.h>
22
3- #if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
4- #include < c10/cuda/driver_api.h>
5- #include < stdexcept>
6- #include < vector>
7- #define HAS_CUDA_GREEN_CONTEXT () 1
8- #else
9- #define HAS_CUDA_GREEN_CONTEXT () 0
10- #endif
11-
123namespace at ::cuda {
4+ GreenContext::GreenContext (uint32_t device_id, uint32_t num_sms) {
5+ #if CUDA_HAS_GREEN_CONTEXT
6+ int driver_version;
7+ C10_CUDA_CHECK (cudaDriverGetVersion (&driver_version));
8+ TORCH_CHECK (
9+ driver_version >= 12080 , " cuda driver too old to use green context!" );
10+ CUcontext pctx = nullptr ;
11+ C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuCtxGetCurrent_ (&pctx));
12+ if (C10_UNLIKELY (!pctx)) {
13+ TORCH_WARN (
14+ " Attempted to create a green context but"
15+ " there was no primary context! Creating a primary context..." );
16+
17+ cudaFree (0 );
18+ }
1319
14- GreenContext::GreenContext (uint32_t device_id, uint32_t num_sms) {
15- #if HAS_CUDA_GREEN_CONTEXT()
16- int driver_version;
17- C10_CUDA_CHECK (cudaDriverGetVersion (&driver_version));
18- TORCH_CHECK (
19- driver_version >= 12080 , " cuda driver too old to use green context!" );
20- CUcontext pctx = nullptr ;
21- C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuCtxGetCurrent_ (&pctx));
22- if (C10_UNLIKELY (!pctx)) {
23- TORCH_WARN (
24- " Attempted to create a green context but"
25- " there was no primary context! Creating a primary context..." );
26-
27- cudaFree (0 );
28- }
20+ CUdevice device;
21+ device_id_ = device_id;
22+ C10_CUDA_DRIVER_CHECK (
23+ c10::cuda::DriverAPI::get ()->cuDeviceGet_ (&device, device_id));
24+
25+ // Get device resources
26+ CUdevResource device_resource;
27+ C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuDeviceGetDevResource_ (
28+ device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
29+
30+ // Split resources
31+ std::vector<CUdevResource> result (1 );
32+ auto result_data = result.data ();
33+ unsigned int nb_groups = 1 ;
34+ CUdevResource remaining;
35+
36+ C10_CUDA_DRIVER_CHECK (
37+ c10::cuda::DriverAPI::get ()->cuDevSmResourceSplitByCount_ (
38+ result_data,
39+ &nb_groups,
40+ &device_resource,
41+ &remaining,
42+ 0 , // default flags
43+ num_sms));
44+
45+ TORCH_CHECK (nb_groups == 1 , " Failed to create single resource group" );
46+
47+ // Generate resource descriptor
48+ CUdevResourceDesc desc;
49+ C10_CUDA_DRIVER_CHECK (
50+ c10::cuda::DriverAPI::get ()->cuDevResourceGenerateDesc_ (
51+ &desc, result_data, 1 ));
2952
30- CUdevice device;
31- device_id_ = device_id;
32- C10_CUDA_DRIVER_CHECK (
33- c10::cuda::DriverAPI::get ()->cuDeviceGet_ (&device, device_id));
34-
35- // Get device resources
36- CUdevResource device_resource;
37- C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuDeviceGetDevResource_ (
38- device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
39-
40- // Split resources
41- std::vector<CUdevResource> result (1 );
42- auto result_data = result.data ();
43- unsigned int nb_groups = 1 ;
44- CUdevResource remaining;
45-
46- C10_CUDA_DRIVER_CHECK (
47- c10::cuda::DriverAPI::get ()->cuDevSmResourceSplitByCount_ (
48- result_data,
49- &nb_groups,
50- &device_resource,
51- &remaining,
52- 0 , // default flags
53- num_sms));
54-
55- TORCH_CHECK (nb_groups == 1 , " Failed to create single resource group" );
56-
57- // Generate resource descriptor
58- CUdevResourceDesc desc;
59- C10_CUDA_DRIVER_CHECK (
60- c10::cuda::DriverAPI::get ()->cuDevResourceGenerateDesc_ (
61- &desc, result_data, 1 ));
62-
63- // Create green context
64- // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
65- // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
66- C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuGreenCtxCreate_ (
67- &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
68-
69- // Convert to regular context
70- C10_CUDA_DRIVER_CHECK (
71- c10::cuda::DriverAPI::get ()->cuCtxFromGreenCtx_ (&context_, green_ctx_));
72- TORCH_CHECK (context_, " Green ctx conversion to regular ctx failed!" );
53+ // Create green context
54+ // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
55+ // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
56+ C10_CUDA_DRIVER_CHECK (c10::cuda::DriverAPI::get ()->cuGreenCtxCreate_ (
57+ &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
58+
59+ // Convert to regular context
60+ C10_CUDA_DRIVER_CHECK (
61+ c10::cuda::DriverAPI::get ()->cuCtxFromGreenCtx_ (&context_, green_ctx_));
62+ TORCH_CHECK (context_, " Green ctx conversion to regular ctx failed!" );
7363#else
74- TORCH_CHECK (false , " Green Context is only supported on CUDA 12.8+!" );
64+ TORCH_CHECK (false , " Green Context is only supported on CUDA 12.8+!" );
7565#endif
7666 }
7767
7868 std::unique_ptr<GreenContext> GreenContext::create (
7969 uint32_t num_sms,
8070 std::optional<uint32_t > device_id) {
81- #if HAS_CUDA_GREEN_CONTEXT()
71+ #if CUDA_HAS_GREEN_CONTEXT
8272 if (!device_id.has_value ()) {
8373 device_id = at::cuda::current_device ();
8474 }
85- return std::unique_ptr <GreenContext>(new GreenContext ( device_id.value (), num_sms) );
75+ return std::make_unique <GreenContext>(device_id.value (), num_sms);
8676#else
8777 TORCH_CHECK (false , " Green Context is only supported on CUDA 12.8+!" );
8878#endif
8979 }
9080
9181 // Implement move operations
9282 GreenContext::GreenContext (GreenContext&& other) noexcept {
93- #if HAS_CUDA_GREEN_CONTEXT()
83+ #if CUDA_HAS_GREEN_CONTEXT
9484 device_id_ = std::exchange (other.device_id_ , -1 );
9585 green_ctx_ = std::exchange (other.green_ctx_ , nullptr );
9686 context_ = std::exchange (other.context_ , nullptr );
@@ -101,7 +91,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
10191 }
10292
10393 GreenContext& GreenContext::operator =(GreenContext&& other) noexcept {
104- #if HAS_CUDA_GREEN_CONTEXT()
94+ #if CUDA_HAS_GREEN_CONTEXT
10595 if (this != &other) {
10696 // Clean up current resources
10797 if (green_ctx_) {
@@ -130,17 +120,33 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
130120 }
131121
132122 GreenContext::~GreenContext () noexcept {
133- #if HAS_CUDA_GREEN_CONTEXT()
123+ #if CUDA_HAS_GREEN_CONTEXT
134124 C10_CUDA_DRIVER_CHECK (
135125 c10::cuda::DriverAPI::get ()->cuGreenCtxDestroy_ (green_ctx_));
136126#else
137127 TORCH_CHECK (false , " Green Context is only supported on CUDA 12.8+!" );
138128#endif
139129 }
140130
131+ // Get the underlying CUDA context
132+ CUcontext GreenContext::getContext () const {
133+ #if CUDA_HAS_GREEN_CONTEXT
134+ return context_;
135+ #else
136+ TORCH_CHECK (false , " Green Context is only supported on CUDA 12.8+!" );
137+ #endif
138+ }
139+
// Accessor for the CUgreenCtx handle stored in `green_ctx_`.
// The whole definition is compiled only when CUDA_HAS_GREEN_CONTEXT is set,
// so this member does not exist on builds without green context support.
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx GreenContext::getGreenContext() const {
  return this->green_ctx_;
}
#endif
146+
141147 // Make this context current
142148 void GreenContext::setContext () {
143- #if HAS_CUDA_GREEN_CONTEXT()
149+ #if CUDA_HAS_GREEN_CONTEXT
144150 auto current_stream = c10::cuda::getCurrentCUDAStream ();
145151 parent_stream_ = current_stream.stream ();
146152
@@ -169,7 +175,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
169175 }
170176
171177 void GreenContext::popContext () {
172- #if HAS_CUDA_GREEN_CONTEXT()
178+ #if CUDA_HAS_GREEN_CONTEXT
173179 // see above note about stream being hardcoded to the default stream
174180 at::cuda::CUDAEvent ev;
175181 ev.record (c10::cuda::getCurrentCUDAStream ());
0 commit comments