@@ -18,6 +18,7 @@ struct libcu_ops {
1818 CUresult (*cuCtxCreate)(CUcontext *pctx, unsigned int flags, CUdevice dev);
1919 CUresult (*cuCtxDestroy)(CUcontext ctx);
2020 CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
21+ CUresult (*cuCtxSetCurrent)(CUcontext ctx);
2122 CUresult (*cuDeviceGet)(CUdevice *device, int ordinal);
2223 CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t size);
2324 CUresult (*cuMemFree)(CUdeviceptr dptr);
@@ -34,6 +35,7 @@ struct libcu_ops {
3435 CUpointer_attribute *attributes,
3536 void **data, CUdeviceptr ptr);
3637 CUresult (*cuStreamSynchronize)(CUstream hStream);
38+ CUresult (*cuCtxSynchronize)(void );
3739} libcu_ops;
3840
3941#if USE_DLOPEN
@@ -48,7 +50,7 @@ struct DlHandleCloser {
4850std::unique_ptr<void , DlHandleCloser> cuDlHandle = nullptr ;
4951int InitCUDAOps () {
5052#ifdef _WIN32
51- const char *lib_name = " cudart .dll" ;
53+ const char *lib_name = " nvcuda .dll" ;
5254#else
5355 const char *lib_name = " libcuda.so" ;
5456#endif
@@ -84,6 +86,12 @@ int InitCUDAOps() {
8486 fprintf (stderr, " cuCtxGetCurrent symbol not found in %s\n " , lib_name);
8587 return -1 ;
8688 }
89+ *(void **)&libcu_ops.cuCtxSetCurrent =
90+ utils_get_symbol_addr (cuDlHandle.get (), " cuCtxSetCurrent" , lib_name);
91+ if (libcu_ops.cuCtxSetCurrent == nullptr ) {
92+ fprintf (stderr, " cuCtxSetCurrent symbol not found in %s\n " , lib_name);
93+ return -1 ;
94+ }
8795 *(void **)&libcu_ops.cuDeviceGet =
8896 utils_get_symbol_addr (cuDlHandle.get (), " cuDeviceGet" , lib_name);
8997 if (libcu_ops.cuDeviceGet == nullptr ) {
@@ -153,6 +161,13 @@ int InitCUDAOps() {
153161 lib_name);
154162 return -1 ;
155163 }
164+ *(void **)&libcu_ops.cuCtxSynchronize = utils_get_symbol_addr (
165+ cuDlHandle.get (), " cuCtxSynchronize" , lib_name);
166+ if (libcu_ops.cuCtxSynchronize == nullptr ) {
167+ fprintf (stderr, " cuCtxSynchronize symbol not found in %s\n " ,
168+ lib_name);
169+ return -1 ;
170+ }
156171
157172 return 0 ;
158173}
@@ -165,6 +180,7 @@ int InitCUDAOps() {
165180 libcu_ops.cuCtxCreate = cuCtxCreate;
166181 libcu_ops.cuCtxDestroy = cuCtxDestroy;
167182 libcu_ops.cuCtxGetCurrent = cuCtxGetCurrent;
183+ libcu_ops.cuCtxSetCurrent = cuCtxSetCurrent;
168184 libcu_ops.cuDeviceGet = cuDeviceGet;
169185 libcu_ops.cuMemAlloc = cuMemAlloc;
170186 libcu_ops.cuMemAllocHost = cuMemAllocHost;
@@ -176,6 +192,7 @@ int InitCUDAOps() {
176192 libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
177193 libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
178194 libcu_ops.cuStreamSynchronize = cuStreamSynchronize;
195+ libcu_ops.cuCtxSynchronize = cuCtxSynchronize;
179196
180197 return 0 ;
181198}
@@ -191,8 +208,6 @@ static int init_cuda_lib(void) {
191208
192209int cuda_fill (CUcontext context, CUdevice device, void *ptr, size_t size,
193210 const void *pattern, size_t pattern_size) {
194-
195- (void )context;
196211 (void )device;
197212 (void )pattern_size;
198213
@@ -202,23 +217,40 @@ int cuda_fill(CUcontext context, CUdevice device, void *ptr, size_t size,
202217 return -1 ;
203218 }
204219
220+ // set required context
221+ CUcontext curr_context = nullptr ;
222+ set_context (context, &curr_context);
223+
205224 int ret = 0 ;
206225 CUresult res =
207226 libcu_ops.cuMemsetD32 ((CUdeviceptr)ptr, *(unsigned int *)pattern,
208227 size / sizeof (unsigned int ));
209228 if (res != CUDA_SUCCESS) {
210- fprintf (stderr, " cuMemsetD32() failed!\n " );
229+ fprintf (stderr, " cuMemsetD32(%llu, %u, %zu) failed!\n " ,
230+ (CUdeviceptr)ptr, *(unsigned int *)pattern,
231+ size / pattern_size);
211232 return -1 ;
212233 }
213234
235+ res = libcu_ops.cuCtxSynchronize ();
236+ if (res != CUDA_SUCCESS) {
237+ fprintf (stderr, " cuCtxSynchronize() failed!\n " );
238+ return -1 ;
239+ }
240+
241+ // restore context
242+ set_context (curr_context, &curr_context);
214243 return ret;
215244}
216245
217- int cuda_copy (CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
218- size_t size) {
219- (void )context;
246+ int cuda_copy (CUcontext context, CUdevice device, void *dst_ptr,
247+ const void *src_ptr, size_t size) {
220248 (void )device;
221249
250+ // set required context
251+ CUcontext curr_context = nullptr ;
252+ set_context (context, &curr_context);
253+
222254 int ret = 0 ;
223255 CUresult res =
224256 libcu_ops.cuMemcpy ((CUdeviceptr)dst_ptr, (CUdeviceptr)src_ptr, size);
@@ -227,12 +259,14 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
227259 return -1 ;
228260 }
229261
230- res = libcu_ops.cuStreamSynchronize ( 0 );
262+ res = libcu_ops.cuCtxSynchronize ( );
231263 if (res != CUDA_SUCCESS) {
232- fprintf (stderr, " cuStreamSynchronize () failed!\n " );
264+ fprintf (stderr, " cuCtxSynchronize () failed!\n " );
233265 return -1 ;
234266 }
235267
268+ // restore context
269+ set_context (curr_context, &curr_context);
236270 return ret;
237271}
238272
@@ -287,6 +321,25 @@ CUcontext get_current_context() {
287321 return context;
288322}
289323
324+ CUresult set_context (CUcontext required_ctx, CUcontext *restore_ctx) {
325+ CUcontext current_ctx = NULL ;
326+ CUresult cu_result = libcu_ops.cuCtxGetCurrent (¤t_ctx);
327+ if (cu_result != CUDA_SUCCESS) {
328+ fprintf (stderr, " cuCtxGetCurrent() failed.\n " );
329+ return cu_result;
330+ }
331+
332+ *restore_ctx = current_ctx;
333+ if (current_ctx != required_ctx) {
334+ cu_result = libcu_ops.cuCtxSetCurrent (required_ctx);
335+ if (cu_result != CUDA_SUCCESS) {
336+ fprintf (stderr, " cuCtxSetCurrent() failed.\n " );
337+ }
338+ }
339+
340+ return cu_result;
341+ }
342+
290343UTIL_ONCE_FLAG cuda_init_flag;
291344int InitResult;
292345void init_cuda_once () {
0 commit comments