Skip to content

Commit 02d0bb9

Browse files
committed
Make KernelDescriptor bind to the current CUDA context.
A `KernelDescriptor` will now store a reference to the current CUDA context and only compare equal to another `KernelDescriptor` if they have a matching CUDA context. This prevents bugs when dealing with multiple CUDA contexts and switching between these. Previously, it was assume that there was only a single CUDA context.
1 parent 7c8a958 commit 02d0bb9

File tree

5 files changed

+179
-34
lines changed

5 files changed

+179
-34
lines changed

include/kernel_launcher/arg.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <cuda.h>
55

6+
#include <array>
67
#include <cstring>
78
#include <iostream>
89
#include <utility>

include/kernel_launcher/cuda.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ struct CudaDevice {
192192
*/
193193
struct CudaContextHandle {
194194
CudaContextHandle() = default;
195-
CudaContextHandle(CUcontext c) : context_(c) {};
195+
CudaContextHandle(CUcontext c);
196196

197197
/**
198198
* Returns the current CUDA context or throws an error if CUDA has not
@@ -205,17 +205,25 @@ struct CudaContextHandle {
205205
*/
206206
CudaDevice device() const;
207207

208-
void with(std::function<void()> f) const;
209-
210208
/**
211209
* Returns the underlying `CUcontext`.
212210
*/
213211
CUcontext get() const {
214212
return context_;
215213
}
216214

215+
bool operator==(const CudaContextHandle& that) const;
216+
bool operator!=(const CudaContextHandle& that) const;
217+
217218
private:
218219
CUcontext context_ = nullptr;
220+
unsigned long long id_ = ~0ULL;
221+
};
222+
223+
struct CudaContextGuard {
224+
CudaContextGuard(CudaContextHandle ctx) : CudaContextGuard(ctx.get()) {}
225+
CudaContextGuard(CUcontext ctx);
226+
~CudaContextGuard();
219227
};
220228

221229
/**

include/kernel_launcher/registry.h

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,28 @@ struct KernelDescriptor {
5353
KernelDescriptor(KernelDescriptor&) noexcept = default;
5454
KernelDescriptor(const KernelDescriptor&) = default;
5555

56-
template<typename D>
57-
KernelDescriptor(D&& descriptor) {
58-
using T = typename std::decay<D>::type;
59-
descriptor_type_ = type_of<T>();
60-
descriptor_ = std::make_shared<T>(std::forward<D>(descriptor));
61-
hash_ = hash_fields(descriptor_type_, descriptor_->hash());
56+
KernelDescriptor(
57+
std::shared_ptr<IKernelDescriptor> descriptor,
58+
CudaContextHandle ctx = CudaContextHandle::current()) :
59+
ctx_(ctx),
60+
descriptor_(std::move(descriptor)) {
61+
const IKernelDescriptor& inner = *descriptor_;
62+
hash_ = hash_fields(
63+
typeid(inner).hash_code(),
64+
descriptor_->hash(),
65+
ctx_.get());
6266
}
6367

68+
template<typename D>
69+
KernelDescriptor(std::shared_ptr<D> descriptor) :
70+
KernelDescriptor(
71+
std::shared_ptr<IKernelDescriptor>(std::move(descriptor))) {}
72+
73+
template<typename D>
74+
KernelDescriptor(D&& descriptor) :
75+
KernelDescriptor(
76+
std::make_shared<std::decay_t<D>>(std::forward<D>(descriptor))) {}
77+
6478
const IKernelDescriptor& get() const {
6579
return *descriptor_;
6680
}
@@ -70,7 +84,7 @@ struct KernelDescriptor {
7084
}
7185

7286
bool operator==(const KernelDescriptor& that) const {
73-
return that.hash_ == hash_ && that.descriptor_type_ == descriptor_type_
87+
return that.hash_ == hash_ && that.ctx_ == ctx_
7488
&& that.descriptor_->equals(*descriptor_);
7589
}
7690

@@ -80,7 +94,7 @@ struct KernelDescriptor {
8094

8195
private:
8296
hash_t hash_;
83-
TypeInfo descriptor_type_;
97+
CudaContextHandle ctx_;
8498
std::shared_ptr<IKernelDescriptor> descriptor_;
8599
};
86100
} // namespace kernel_launcher

src/cuda.cpp

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,23 @@
88

99
namespace kernel_launcher {
1010

11-
void cuda_check(CUresult result, const char* msg) {
12-
if (result != CUDA_SUCCESS) {
13-
const char* name = "???";
14-
const char* description = "???";
11+
CudaException build_cuda_exception(CUresult& result, const char* msg) {
12+
const char* name = "???";
13+
const char* description = "???";
1514

16-
// Ignore error since we are already handling another error
17-
cuGetErrorName(result, &name);
18-
cuGetErrorString(result, &description);
15+
// Ignore error since we are already handling another error
16+
cuGetErrorName(result, &name);
17+
cuGetErrorString(result, &description);
18+
19+
std::stringstream display;
20+
display << "CUDA error: " << name << " (" << description << "): " << msg;
21+
auto e = CudaException(result, display.str());
22+
return e;
23+
}
1924

20-
std::stringstream display;
21-
display << "CUDA error: " << name << " (" << description
22-
<< "): " << msg;
23-
throw CudaException(result, display.str());
25+
void cuda_check(CUresult result, const char* msg) {
26+
if (result != CUDA_SUCCESS) {
27+
throw build_cuda_exception(result, msg);
2428
}
2529
}
2630

@@ -129,7 +133,15 @@ std::string CudaDevice::uuid() const {
129133
CudaArch CudaDevice::arch() const {
130134
int minor = attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR);
131135
int major = attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR);
132-
return CudaArch(major, minor);
136+
return {major, minor};
137+
}
138+
139+
CudaContextHandle::CudaContextHandle(CUcontext c) {
140+
context_ = c;
141+
142+
#if CUDA_VERSION >= 12000
143+
KERNEL_LAUNCHER_CUDA_CHECK(cuCtxGetId(context_, &id_));
144+
#endif
133145
}
134146

135147
CudaContextHandle CudaContextHandle::current() {
@@ -146,21 +158,27 @@ CudaContextHandle CudaContextHandle::current() {
146158
}
147159

148160
CudaDevice CudaContextHandle::device() const {
161+
CudaContextGuard guard {context_};
149162
CUdevice d = -1;
150-
with([&]() { KERNEL_LAUNCHER_CUDA_CHECK(cuCtxGetDevice(&d)); });
163+
KERNEL_LAUNCHER_CUDA_CHECK(cuCtxGetDevice(&d));
151164
return CudaDevice(d);
152165
}
153166

154-
void CudaContextHandle::with(std::function<void()> f) const {
155-
KERNEL_LAUNCHER_CUDA_CHECK(cuCtxPushCurrent(context_));
156-
try {
157-
f();
158-
KERNEL_LAUNCHER_CUDA_CHECK(cuCtxPopCurrent(nullptr));
159-
} catch (...) {
160-
// Ignore errors. There is not much we can do at this point.
161-
cuCtxPopCurrent(nullptr);
162-
throw;
163-
}
167+
bool CudaContextHandle::operator==(const CudaContextHandle& that) const {
168+
return id_ == that.id_ && context_ == that.context_;
169+
}
170+
171+
bool CudaContextHandle::operator!=(const CudaContextHandle& that) const {
172+
return !(*this == that);
173+
}
174+
175+
CudaContextGuard::CudaContextGuard(CUcontext ctx) {
176+
KERNEL_LAUNCHER_CUDA_CHECK(cuCtxPushCurrent(ctx));
177+
}
178+
179+
CudaContextGuard::~CudaContextGuard() {
180+
CUcontext current;
181+
KERNEL_LAUNCHER_CUDA_CHECK(cuCtxPopCurrent(&current));
164182
}
165183

166184
void cuda_raw_copy(const void* src, void* dst, size_t num_bytes) {

tests/registry.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,110 @@ struct VectorAddDescriptor: IKernelDescriptor {
1515
}
1616
};
1717

18+
struct MatrixMulDescriptor: IKernelDescriptor {
19+
MatrixMulDescriptor(int size) : size_(size) {}
20+
21+
KernelBuilder build() const override {
22+
return KernelBuilder("matrix_mul", "TODO");
23+
}
24+
25+
bool equals(const IKernelDescriptor& that) const override {
26+
if (auto ptr = dynamic_cast<const MatrixMulDescriptor*>(&that)) {
27+
return ptr->size_ == size_;
28+
} else {
29+
return false;
30+
}
31+
}
32+
33+
hash_t hash() const override {
34+
return size_;
35+
}
36+
37+
int size_;
38+
};
39+
40+
TEST_CASE("KernelDescriptor", "[CUDA]") {
41+
CUcontext ctx, ctx2;
42+
KERNEL_LAUNCHER_CUDA_CHECK(cuInit(0));
43+
44+
KERNEL_LAUNCHER_CUDA_CHECK(cuCtxCreate(&ctx, 0, 0));
45+
auto a = KernelDescriptor(VectorAddDescriptor());
46+
auto b = KernelDescriptor(std::make_shared<MatrixMulDescriptor>(1));
47+
auto c = KernelDescriptor(std::shared_ptr<IKernelDescriptor>(
48+
std::make_shared<MatrixMulDescriptor>(1)));
49+
auto d = KernelDescriptor(MatrixMulDescriptor(2));
50+
51+
// A KernelDescriptor is based on the current CUDA context.
52+
// Creating a new CUDA context here will mean that new descriptors will be
53+
// based on a different CUDA context than before.
54+
KERNEL_LAUNCHER_CUDA_CHECK(cuCtxCreate(&ctx2, 0, 0));
55+
auto e = KernelDescriptor(VectorAddDescriptor());
56+
57+
CHECK(a == a);
58+
CHECK_FALSE(a == b);
59+
CHECK_FALSE(a == c);
60+
CHECK_FALSE(a == d);
61+
CHECK_FALSE(a == e);
62+
63+
CHECK_FALSE(b == a);
64+
CHECK(b == b);
65+
CHECK(b == c);
66+
CHECK_FALSE(b == d);
67+
CHECK_FALSE(b == e);
68+
69+
CHECK_FALSE(c == a);
70+
CHECK(c == b);
71+
CHECK(c == c);
72+
CHECK_FALSE(c == d);
73+
CHECK_FALSE(c == e);
74+
75+
CHECK_FALSE(d == a);
76+
CHECK_FALSE(d == b);
77+
CHECK_FALSE(d == c);
78+
CHECK(d == d);
79+
CHECK_FALSE(d == e);
80+
81+
CHECK_FALSE(e == a);
82+
CHECK_FALSE(e == b);
83+
CHECK_FALSE(e == c);
84+
CHECK_FALSE(e == d);
85+
CHECK(e == e);
86+
87+
// These match the ones above
88+
CHECK(a.hash() == a.hash());
89+
CHECK_FALSE(a.hash() == b.hash());
90+
CHECK_FALSE(a.hash() == c.hash());
91+
CHECK_FALSE(a.hash() == d.hash());
92+
CHECK_FALSE(a.hash() == e.hash());
93+
94+
CHECK_FALSE(b.hash() == a.hash());
95+
CHECK(b.hash() == b.hash());
96+
CHECK(b.hash() == c.hash());
97+
CHECK_FALSE(b.hash() == d.hash());
98+
CHECK_FALSE(b.hash() == e.hash());
99+
100+
CHECK_FALSE(c.hash() == a.hash());
101+
CHECK(c.hash() == b.hash());
102+
CHECK(c.hash() == c.hash());
103+
CHECK_FALSE(c.hash() == d.hash());
104+
CHECK_FALSE(c.hash() == e.hash());
105+
106+
CHECK_FALSE(d.hash() == a.hash());
107+
CHECK_FALSE(d.hash() == b.hash());
108+
CHECK_FALSE(d.hash() == c.hash());
109+
CHECK(d.hash() == d.hash());
110+
CHECK_FALSE(d.hash() == e.hash());
111+
112+
CHECK_FALSE(e.hash() == a.hash());
113+
CHECK_FALSE(e.hash() == b.hash());
114+
CHECK_FALSE(e.hash() == c.hash());
115+
CHECK_FALSE(e.hash() == d.hash());
116+
CHECK(e.hash() == e.hash());
117+
118+
KERNEL_LAUNCHER_CUDA_CHECK(cuCtxDestroy(ctx));
119+
KERNEL_LAUNCHER_CUDA_CHECK(cuCtxDestroy(ctx2));
120+
}
121+
18122
TEST_CASE("KernelRegistry", "[CUDA]") {
19123
CUcontext ctx;
20124
KERNEL_LAUNCHER_CUDA_CHECK(cuInit(0));

0 commit comments

Comments
 (0)