@@ -18,6 +18,7 @@ struct libcu_ops {
18
18
CUresult (*cuCtxCreate)(CUcontext *pctx, unsigned int flags, CUdevice dev);
19
19
CUresult (*cuCtxDestroy)(CUcontext ctx);
20
20
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
21
+ CUresult (*cuCtxSetCurrent)(CUcontext ctx);
21
22
CUresult (*cuDeviceGet)(CUdevice *device, int ordinal);
22
23
CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t size);
23
24
CUresult (*cuMemFree)(CUdeviceptr dptr);
@@ -34,6 +35,7 @@ struct libcu_ops {
34
35
CUpointer_attribute *attributes,
35
36
void **data, CUdeviceptr ptr);
36
37
CUresult (*cuStreamSynchronize)(CUstream hStream);
38
+ CUresult (*cuCtxSynchronize)(void );
37
39
} libcu_ops;
38
40
39
41
#if USE_DLOPEN
@@ -48,7 +50,7 @@ struct DlHandleCloser {
48
50
std::unique_ptr<void , DlHandleCloser> cuDlHandle = nullptr ;
49
51
int InitCUDAOps () {
50
52
#ifdef _WIN32
51
- const char *lib_name = " cudart .dll" ;
53
+ const char *lib_name = " nvcuda .dll" ;
52
54
#else
53
55
const char *lib_name = " libcuda.so" ;
54
56
#endif
@@ -84,6 +86,12 @@ int InitCUDAOps() {
84
86
fprintf (stderr, " cuCtxGetCurrent symbol not found in %s\n " , lib_name);
85
87
return -1 ;
86
88
}
89
+ *(void **)&libcu_ops.cuCtxSetCurrent =
90
+ utils_get_symbol_addr (cuDlHandle.get (), " cuCtxSetCurrent" , lib_name);
91
+ if (libcu_ops.cuCtxSetCurrent == nullptr ) {
92
+ fprintf (stderr, " cuCtxSetCurrent symbol not found in %s\n " , lib_name);
93
+ return -1 ;
94
+ }
87
95
*(void **)&libcu_ops.cuDeviceGet =
88
96
utils_get_symbol_addr (cuDlHandle.get (), " cuDeviceGet" , lib_name);
89
97
if (libcu_ops.cuDeviceGet == nullptr ) {
@@ -153,6 +161,12 @@ int InitCUDAOps() {
153
161
lib_name);
154
162
return -1 ;
155
163
}
164
+ *(void **)&libcu_ops.cuCtxSynchronize =
165
+ utils_get_symbol_addr (cuDlHandle.get (), " cuCtxSynchronize" , lib_name);
166
+ if (libcu_ops.cuCtxSynchronize == nullptr ) {
167
+ fprintf (stderr, " cuCtxSynchronize symbol not found in %s\n " , lib_name);
168
+ return -1 ;
169
+ }
156
170
157
171
return 0 ;
158
172
}
@@ -165,6 +179,7 @@ int InitCUDAOps() {
165
179
libcu_ops.cuCtxCreate = cuCtxCreate;
166
180
libcu_ops.cuCtxDestroy = cuCtxDestroy;
167
181
libcu_ops.cuCtxGetCurrent = cuCtxGetCurrent;
182
+ libcu_ops.cuCtxSetCurrent = cuCtxSetCurrent;
168
183
libcu_ops.cuDeviceGet = cuDeviceGet;
169
184
libcu_ops.cuMemAlloc = cuMemAlloc;
170
185
libcu_ops.cuMemAllocHost = cuMemAllocHost;
@@ -176,11 +191,31 @@ int InitCUDAOps() {
176
191
libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
177
192
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
178
193
libcu_ops.cuStreamSynchronize = cuStreamSynchronize;
194
+ libcu_ops.cuCtxSynchronize = cuCtxSynchronize;
179
195
180
196
return 0 ;
181
197
}
182
198
#endif // USE_DLOPEN
183
199
200
+ static CUresult set_context (CUcontext required_ctx, CUcontext *restore_ctx) {
201
+ CUcontext current_ctx = NULL ;
202
+ CUresult cu_result = libcu_ops.cuCtxGetCurrent (¤t_ctx);
203
+ if (cu_result != CUDA_SUCCESS) {
204
+ fprintf (stderr, " cuCtxGetCurrent() failed.\n " );
205
+ return cu_result;
206
+ }
207
+
208
+ *restore_ctx = current_ctx;
209
+ if (current_ctx != required_ctx) {
210
+ cu_result = libcu_ops.cuCtxSetCurrent (required_ctx);
211
+ if (cu_result != CUDA_SUCCESS) {
212
+ fprintf (stderr, " cuCtxSetCurrent() failed.\n " );
213
+ }
214
+ }
215
+
216
+ return cu_result;
217
+ }
218
+
184
219
static int init_cuda_lib (void ) {
185
220
CUresult result = libcu_ops.cuInit (0 );
186
221
if (result != CUDA_SUCCESS) {
@@ -191,8 +226,6 @@ static int init_cuda_lib(void) {
191
226
192
227
int cuda_fill (CUcontext context, CUdevice device, void *ptr, size_t size,
193
228
const void *pattern, size_t pattern_size) {
194
-
195
- (void )context;
196
229
(void )device;
197
230
(void )pattern_size;
198
231
@@ -202,23 +235,40 @@ int cuda_fill(CUcontext context, CUdevice device, void *ptr, size_t size,
202
235
return -1 ;
203
236
}
204
237
238
+ // set required context
239
+ CUcontext curr_context = nullptr ;
240
+ set_context (context, &curr_context);
241
+
205
242
int ret = 0 ;
206
243
CUresult res =
207
244
libcu_ops.cuMemsetD32 ((CUdeviceptr)ptr, *(unsigned int *)pattern,
208
245
size / sizeof (unsigned int ));
209
246
if (res != CUDA_SUCCESS) {
210
- fprintf (stderr, " cuMemsetD32() failed!\n " );
247
+ fprintf (stderr, " cuMemsetD32(%llu, %u, %zu) failed!\n " ,
248
+ (CUdeviceptr)ptr, *(unsigned int *)pattern,
249
+ size / pattern_size);
250
+ return -1 ;
251
+ }
252
+
253
+ res = libcu_ops.cuCtxSynchronize ();
254
+ if (res != CUDA_SUCCESS) {
255
+ fprintf (stderr, " cuCtxSynchronize() failed!\n " );
211
256
return -1 ;
212
257
}
213
258
259
+ // restore context
260
+ set_context (curr_context, &curr_context);
214
261
return ret;
215
262
}
216
263
217
- int cuda_copy (CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
218
- size_t size) {
219
- (void )context;
264
+ int cuda_copy (CUcontext context, CUdevice device, void *dst_ptr,
265
+ const void *src_ptr, size_t size) {
220
266
(void )device;
221
267
268
+ // set required context
269
+ CUcontext curr_context = nullptr ;
270
+ set_context (context, &curr_context);
271
+
222
272
int ret = 0 ;
223
273
CUresult res =
224
274
libcu_ops.cuMemcpy ((CUdeviceptr)dst_ptr, (CUdeviceptr)src_ptr, size);
@@ -227,12 +277,14 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
227
277
return -1 ;
228
278
}
229
279
230
- res = libcu_ops.cuStreamSynchronize ( 0 );
280
+ res = libcu_ops.cuCtxSynchronize ( );
231
281
if (res != CUDA_SUCCESS) {
232
- fprintf (stderr, " cuStreamSynchronize () failed!\n " );
282
+ fprintf (stderr, " cuCtxSynchronize () failed!\n " );
233
283
return -1 ;
234
284
}
235
285
286
+ // restore context
287
+ set_context (curr_context, &curr_context);
236
288
return ret;
237
289
}
238
290
0 commit comments