@@ -18,6 +18,7 @@ struct libcu_ops {
1818 CUresult (*cuCtxCreate)(CUcontext *pctx, unsigned int flags, CUdevice dev);
1919 CUresult (*cuCtxDestroy)(CUcontext ctx);
2020 CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
21+ CUresult (*cuCtxSetCurrent)(CUcontext ctx);
2122 CUresult (*cuDeviceGet)(CUdevice *device, int ordinal);
2223 CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t size);
2324 CUresult (*cuMemFree)(CUdeviceptr dptr);
@@ -34,6 +35,7 @@ struct libcu_ops {
3435 CUpointer_attribute *attributes,
3536 void **data, CUdeviceptr ptr);
3637 CUresult (*cuStreamSynchronize)(CUstream hStream);
38+ CUresult (*cuCtxSynchronize)(void );
3739} libcu_ops;
3840
3941#if USE_DLOPEN
@@ -48,7 +50,7 @@ struct DlHandleCloser {
4850std::unique_ptr<void , DlHandleCloser> cuDlHandle = nullptr ;
4951int InitCUDAOps () {
5052#ifdef _WIN32
51- const char *lib_name = " cudart .dll" ;
53+ const char *lib_name = " nvcuda .dll" ;
5254#else
5355 const char *lib_name = " libcuda.so" ;
5456#endif
@@ -84,6 +86,12 @@ int InitCUDAOps() {
8486 fprintf (stderr, " cuCtxGetCurrent symbol not found in %s\n " , lib_name);
8587 return -1 ;
8688 }
89+ *(void **)&libcu_ops.cuCtxSetCurrent =
90+ utils_get_symbol_addr (cuDlHandle.get (), " cuCtxSetCurrent" , lib_name);
91+ if (libcu_ops.cuCtxSetCurrent == nullptr ) {
92+ fprintf (stderr, " cuCtxSetCurrent symbol not found in %s\n " , lib_name);
93+ return -1 ;
94+ }
8795 *(void **)&libcu_ops.cuDeviceGet =
8896 utils_get_symbol_addr (cuDlHandle.get (), " cuDeviceGet" , lib_name);
8997 if (libcu_ops.cuDeviceGet == nullptr ) {
@@ -153,6 +161,12 @@ int InitCUDAOps() {
153161 lib_name);
154162 return -1 ;
155163 }
164+ *(void **)&libcu_ops.cuCtxSynchronize =
165+ utils_get_symbol_addr (cuDlHandle.get (), " cuCtxSynchronize" , lib_name);
166+ if (libcu_ops.cuCtxSynchronize == nullptr ) {
167+ fprintf (stderr, " cuCtxSynchronize symbol not found in %s\n " , lib_name);
168+ return -1 ;
169+ }
156170
157171 return 0 ;
158172}
@@ -165,6 +179,7 @@ int InitCUDAOps() {
165179 libcu_ops.cuCtxCreate = cuCtxCreate;
166180 libcu_ops.cuCtxDestroy = cuCtxDestroy;
167181 libcu_ops.cuCtxGetCurrent = cuCtxGetCurrent;
182+ libcu_ops.cuCtxSetCurrent = cuCtxSetCurrent;
168183 libcu_ops.cuDeviceGet = cuDeviceGet;
169184 libcu_ops.cuMemAlloc = cuMemAlloc;
170185 libcu_ops.cuMemAllocHost = cuMemAllocHost;
@@ -176,6 +191,7 @@ int InitCUDAOps() {
176191 libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
177192 libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
178193 libcu_ops.cuStreamSynchronize = cuStreamSynchronize;
194+ libcu_ops.cuCtxSynchronize = cuCtxSynchronize;
179195
180196 return 0 ;
181197}
@@ -191,8 +207,6 @@ static int init_cuda_lib(void) {
191207
192208int cuda_fill (CUcontext context, CUdevice device, void *ptr, size_t size,
193209 const void *pattern, size_t pattern_size) {
194-
195- (void )context;
196210 (void )device;
197211 (void )pattern_size;
198212
@@ -202,23 +216,40 @@ int cuda_fill(CUcontext context, CUdevice device, void *ptr, size_t size,
202216 return -1 ;
203217 }
204218
219+ // set required context
220+ CUcontext curr_context = nullptr ;
221+ set_context (context, &curr_context);
222+
205223 int ret = 0 ;
206224 CUresult res =
207225 libcu_ops.cuMemsetD32 ((CUdeviceptr)ptr, *(unsigned int *)pattern,
208226 size / sizeof (unsigned int ));
209227 if (res != CUDA_SUCCESS) {
210- fprintf (stderr, " cuMemsetD32() failed!\n " );
228+ fprintf (stderr, " cuMemsetD32(%llu, %u, %zu) failed!\n " ,
229+ (CUdeviceptr)ptr, *(unsigned int *)pattern,
230+ size / pattern_size);
231+ return -1 ;
232+ }
233+
234+ res = libcu_ops.cuCtxSynchronize ();
235+ if (res != CUDA_SUCCESS) {
236+ fprintf (stderr, " cuCtxSynchronize() failed!\n " );
211237 return -1 ;
212238 }
213239
240+ // restore context
241+ set_context (curr_context, &curr_context);
214242 return ret;
215243}
216244
217- int cuda_copy (CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
218- size_t size) {
219- (void )context;
245+ int cuda_copy (CUcontext context, CUdevice device, void *dst_ptr,
246+ const void *src_ptr, size_t size) {
220247 (void )device;
221248
249+ // set required context
250+ CUcontext curr_context = nullptr ;
251+ set_context (context, &curr_context);
252+
222253 int ret = 0 ;
223254 CUresult res =
224255 libcu_ops.cuMemcpy ((CUdeviceptr)dst_ptr, (CUdeviceptr)src_ptr, size);
@@ -227,12 +258,14 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
227258 return -1 ;
228259 }
229260
230- res = libcu_ops.cuStreamSynchronize ( 0 );
261+ res = libcu_ops.cuCtxSynchronize ( );
231262 if (res != CUDA_SUCCESS) {
232- fprintf (stderr, " cuStreamSynchronize () failed!\n " );
263+ fprintf (stderr, " cuCtxSynchronize () failed!\n " );
233264 return -1 ;
234265 }
235266
267+ // restore context
268+ set_context (curr_context, &curr_context);
236269 return ret;
237270}
238271
@@ -287,6 +320,25 @@ CUcontext get_current_context() {
287320 return context;
288321}
289322
323+ static CUresult set_context (CUcontext required_ctx, CUcontext *restore_ctx) {
324+ CUcontext current_ctx = NULL ;
325+ CUresult cu_result = libcu_ops.cuCtxGetCurrent (¤t_ctx);
326+ if (cu_result != CUDA_SUCCESS) {
327+ fprintf (stderr, " cuCtxGetCurrent() failed.\n " );
328+ return cu_result;
329+ }
330+
331+ *restore_ctx = current_ctx;
332+ if (current_ctx != required_ctx) {
333+ cu_result = libcu_ops.cuCtxSetCurrent (required_ctx);
334+ if (cu_result != CUDA_SUCCESS) {
335+ fprintf (stderr, " cuCtxSetCurrent() failed.\n " );
336+ }
337+ }
338+
339+ return cu_result;
340+ }
341+
290342UTIL_ONCE_FLAG cuda_init_flag;
291343int InitResult;
292344void init_cuda_once () {
0 commit comments