Skip to content

Commit 5e7272b

Browse files
Revert "[BE] Move GreenContext implementation details to cpp (pytorch#166462)"
This reverts commit afaaaa3. Reverted pytorch#166462 on behalf of https://github.com/atalman due to multiple internal build failures (see the comment on pytorch#166462).
1 parent 1dd6b76 commit 5e7272b

File tree

2 files changed

+109
-88
lines changed

2 files changed

+109
-88
lines changed

aten/src/ATen/cuda/CUDAGreenContext.cpp

Lines changed: 81 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -1,96 +1,86 @@
11
#include <ATen/cuda/CUDAGreenContext.h>
22

3-
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
4-
#include <c10/cuda/driver_api.h>
5-
#include <stdexcept>
6-
#include <vector>
7-
#define HAS_CUDA_GREEN_CONTEXT() 1
8-
#else
9-
#define HAS_CUDA_GREEN_CONTEXT() 0
10-
#endif
11-
123
namespace at::cuda {
4+
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
5+
#if CUDA_HAS_GREEN_CONTEXT
6+
int driver_version;
7+
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
8+
TORCH_CHECK(
9+
driver_version >= 12080, "cuda driver too old to use green context!");
10+
CUcontext pctx = nullptr;
11+
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
12+
if (C10_UNLIKELY(!pctx)) {
13+
TORCH_WARN(
14+
"Attempted to create a green context but"
15+
" there was no primary context! Creating a primary context...");
16+
17+
cudaFree(0);
18+
}
1319

14-
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
15-
#if HAS_CUDA_GREEN_CONTEXT()
16-
int driver_version;
17-
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
18-
TORCH_CHECK(
19-
driver_version >= 12080, "cuda driver too old to use green context!");
20-
CUcontext pctx = nullptr;
21-
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
22-
if (C10_UNLIKELY(!pctx)) {
23-
TORCH_WARN(
24-
"Attempted to create a green context but"
25-
" there was no primary context! Creating a primary context...");
26-
27-
cudaFree(0);
28-
}
20+
CUdevice device;
21+
device_id_ = device_id;
22+
C10_CUDA_DRIVER_CHECK(
23+
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
24+
25+
// Get device resources
26+
CUdevResource device_resource;
27+
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
28+
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
29+
30+
// Split resources
31+
std::vector<CUdevResource> result(1);
32+
auto result_data = result.data();
33+
unsigned int nb_groups = 1;
34+
CUdevResource remaining;
35+
36+
C10_CUDA_DRIVER_CHECK(
37+
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
38+
result_data,
39+
&nb_groups,
40+
&device_resource,
41+
&remaining,
42+
0, // default flags
43+
num_sms));
44+
45+
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
46+
47+
// Generate resource descriptor
48+
CUdevResourceDesc desc;
49+
C10_CUDA_DRIVER_CHECK(
50+
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
51+
&desc, result_data, 1));
2952

30-
CUdevice device;
31-
device_id_ = device_id;
32-
C10_CUDA_DRIVER_CHECK(
33-
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
34-
35-
// Get device resources
36-
CUdevResource device_resource;
37-
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
38-
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
39-
40-
// Split resources
41-
std::vector<CUdevResource> result(1);
42-
auto result_data = result.data();
43-
unsigned int nb_groups = 1;
44-
CUdevResource remaining;
45-
46-
C10_CUDA_DRIVER_CHECK(
47-
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
48-
result_data,
49-
&nb_groups,
50-
&device_resource,
51-
&remaining,
52-
0, // default flags
53-
num_sms));
54-
55-
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
56-
57-
// Generate resource descriptor
58-
CUdevResourceDesc desc;
59-
C10_CUDA_DRIVER_CHECK(
60-
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
61-
&desc, result_data, 1));
62-
63-
// Create green context
64-
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
65-
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
66-
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
67-
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
68-
69-
// Convert to regular context
70-
C10_CUDA_DRIVER_CHECK(
71-
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
72-
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
53+
// Create green context
54+
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
55+
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
56+
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
57+
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
58+
59+
// Convert to regular context
60+
C10_CUDA_DRIVER_CHECK(
61+
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
62+
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
7363
#else
74-
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
64+
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
7565
#endif
7666
}
7767

7868
std::unique_ptr<GreenContext> GreenContext::create(
7969
uint32_t num_sms,
8070
std::optional<uint32_t> device_id) {
81-
#if HAS_CUDA_GREEN_CONTEXT()
71+
#if CUDA_HAS_GREEN_CONTEXT
8272
if (!device_id.has_value()) {
8373
device_id = at::cuda::current_device();
8474
}
85-
return std::unique_ptr<GreenContext>(new GreenContext(device_id.value(), num_sms));
75+
return std::make_unique<GreenContext>(device_id.value(), num_sms);
8676
#else
8777
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
8878
#endif
8979
}
9080

9181
// Implement move operations
9282
GreenContext::GreenContext(GreenContext&& other) noexcept{
93-
#if HAS_CUDA_GREEN_CONTEXT()
83+
#if CUDA_HAS_GREEN_CONTEXT
9484
device_id_ = std::exchange(other.device_id_, -1);
9585
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
9686
context_ = std::exchange(other.context_, nullptr);
@@ -101,7 +91,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
10191
}
10292

10393
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
104-
#if HAS_CUDA_GREEN_CONTEXT()
94+
#if CUDA_HAS_GREEN_CONTEXT
10595
if (this != &other) {
10696
// Clean up current resources
10797
if (green_ctx_) {
@@ -130,17 +120,33 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
130120
}
131121

132122
GreenContext::~GreenContext() noexcept{
133-
#if HAS_CUDA_GREEN_CONTEXT()
123+
#if CUDA_HAS_GREEN_CONTEXT
134124
C10_CUDA_DRIVER_CHECK(
135125
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
136126
#else
137127
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
138128
#endif
139129
}
140130

131+
// Get the underlying CUDA context
132+
CUcontext GreenContext::getContext() const {
133+
#if CUDA_HAS_GREEN_CONTEXT
134+
return context_;
135+
#else
136+
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
137+
#endif
138+
}
139+
140+
// Get the underlying green context
141+
#if CUDA_HAS_GREEN_CONTEXT
142+
CUgreenCtx GreenContext::getGreenContext() const {
143+
return green_ctx_;
144+
}
145+
#endif
146+
141147
// Make this context current
142148
void GreenContext::setContext() {
143-
#if HAS_CUDA_GREEN_CONTEXT()
149+
#if CUDA_HAS_GREEN_CONTEXT
144150
auto current_stream = c10::cuda::getCurrentCUDAStream();
145151
parent_stream_ = current_stream.stream();
146152

@@ -169,7 +175,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
169175
}
170176

171177
void GreenContext::popContext() {
172-
#if HAS_CUDA_GREEN_CONTEXT()
178+
#if CUDA_HAS_GREEN_CONTEXT
173179
// see above note about stream being hardcoded to the default stream
174180
at::cuda::CUDAEvent ev;
175181
ev.record(c10::cuda::getCurrentCUDAStream());
aten/src/ATen/cuda/CUDAGreenContext.h

Lines changed: 28 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,38 +1,53 @@
11
#pragma once
22
#include <ATen/cuda/CUDAEvent.h>
3-
#include <cuda.h>
43

5-
// Forward declare green context as opaque ptr
6-
typedef struct CUgreenCtx_st* CUgreenCtx;
4+
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
5+
#include <c10/cuda/driver_api.h>
6+
#include <cuda.h>
7+
#include <memory>
8+
#include <stdexcept>
9+
#include <vector>
10+
#define CUDA_HAS_GREEN_CONTEXT 1
11+
#else
12+
#define CUDA_HAS_GREEN_CONTEXT 0
13+
#endif
714

815
namespace at::cuda {
916

1017
class TORCH_CUDA_CPP_API GreenContext {
1118
public:
12-
// Green context creation
13-
static std::unique_ptr<GreenContext> create(
14-
uint32_t num_sms,
15-
std::optional<uint32_t> device_id);
16-
~GreenContext() noexcept;
19+
GreenContext(uint32_t device_id, uint32_t num_sms);
20+
21+
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
1722

1823
// Delete copy constructor and assignment
1924
GreenContext(const GreenContext&) = delete;
2025
GreenContext& operator=(const GreenContext&) = delete;
2126

27+
// Implement move operations
28+
GreenContext(GreenContext&& other) noexcept;
29+
GreenContext& operator=(GreenContext&& other) noexcept;
30+
~GreenContext() noexcept;
31+
32+
// Get the underlying CUDA context
33+
CUcontext getContext() const;
34+
35+
// Get the underlying green context
36+
#if CUDA_HAS_GREEN_CONTEXT
37+
CUgreenCtx getGreenContext() const;
38+
#endif
39+
2240
// Make this context current
2341
void setContext();
2442

2543
void popContext();
2644

2745
private:
28-
GreenContext(uint32_t device_id, uint32_t num_sms);
29-
// Implement move operations
30-
GreenContext(GreenContext&& other) noexcept;
31-
GreenContext& operator=(GreenContext&& other) noexcept;
32-
46+
#if CUDA_HAS_GREEN_CONTEXT
3347
int32_t device_id_ = -1;
3448
CUgreenCtx green_ctx_ = nullptr;
3549
CUcontext context_ = nullptr;
3650
cudaStream_t parent_stream_ = nullptr;
51+
#endif
3752
};
3853
} // namespace at::cuda

0 commit comments

Comments
 (0)