@@ -51,6 +51,8 @@ typedef struct cu_ops_t {
5151
5252 CUresult (* cuGetErrorName )(CUresult error , const char * * pStr );
5353 CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
54+ CUresult (* cuCtxGetCurrent )(CUcontext * pctx );
55+ CUresult (* cuCtxSetCurrent )(CUcontext ctx );
5456 CUresult (* cuIpcGetMemHandle )(CUipcMemHandle * pHandle , CUdeviceptr dptr );
5557 CUresult (* cuIpcOpenMemHandle )(CUdeviceptr * pdptr , CUipcMemHandle handle ,
5658 unsigned int Flags );
@@ -121,6 +123,10 @@ static void init_cu_global_state(void) {
121123 utils_get_symbol_addr (0 , "cuGetErrorName" , lib_name );
122124 * (void * * )& g_cu_ops .cuGetErrorString =
123125 utils_get_symbol_addr (0 , "cuGetErrorString" , lib_name );
126+ * (void * * )& g_cu_ops .cuCtxGetCurrent =
127+ utils_get_symbol_addr (0 , "cuCtxGetCurrent" , lib_name );
128+ * (void * * )& g_cu_ops .cuCtxSetCurrent =
129+ utils_get_symbol_addr (0 , "cuCtxSetCurrent" , lib_name );
124130 * (void * * )& g_cu_ops .cuIpcGetMemHandle =
125131 utils_get_symbol_addr (0 , "cuIpcGetMemHandle" , lib_name );
126132 * (void * * )& g_cu_ops .cuIpcOpenMemHandle =
@@ -132,6 +138,7 @@ static void init_cu_global_state(void) {
132138 !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
133139 !g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
134140 !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
141+ !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ||
135142 !g_cu_ops .cuIpcGetMemHandle || !g_cu_ops .cuIpcOpenMemHandle ||
136143 !g_cu_ops .cuIpcCloseMemHandle ) {
137144 LOG_ERR ("Required CUDA symbols not found." );
@@ -202,6 +209,31 @@ static void cu_memory_provider_finalize(void *provider) {
202209 umf_ba_global_free (provider );
203210}
204211
212+ /*
213+ * This function is used by the CUDA provider to make sure that
214+ * the required context is set. If the current context is
215+ * not the required one, it will be saved in restore_ctx.
216+ */
217+ static inline umf_result_t set_context (CUcontext required_ctx ,
218+ CUcontext * restore_ctx ) {
219+ CUcontext current_ctx = NULL ;
220+ CUresult cu_result = g_cu_ops .cuCtxGetCurrent (& current_ctx );
221+ if (cu_result != CUDA_SUCCESS ) {
222+ LOG_ERR ("cuCtxGetCurrent() failed." );
223+ return cu2umf_result (cu_result );
224+ }
225+ * restore_ctx = current_ctx ;
226+ if (current_ctx != required_ctx ) {
227+ cu_result = g_cu_ops .cuCtxSetCurrent (required_ctx );
228+ if (cu_result != CUDA_SUCCESS ) {
229+ LOG_ERR ("cuCtxSetCurrent() failed." );
230+ return cu2umf_result (cu_result );
231+ }
232+ }
233+
234+ return UMF_RESULT_SUCCESS ;
235+ }
236+
205237static umf_result_t cu_memory_provider_alloc (void * provider , size_t size ,
206238 size_t alignment ,
207239 void * * resultPtr ) {
@@ -217,6 +249,13 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
217249 return UMF_RESULT_ERROR_NOT_SUPPORTED ;
218250 }
219251
252+ // Remember current context and set the one from the provider
253+ CUcontext restore_ctx = NULL ;
254+ umf_result_t umf_result = set_context (cu_provider -> context , & restore_ctx );
255+ if (umf_result != UMF_RESULT_SUCCESS ) {
256+ return umf_result ;
257+ }
258+
220259 CUresult cu_result = CUDA_SUCCESS ;
221260 switch (cu_provider -> memory_type ) {
222261 case UMF_MEMORY_TYPE_HOST : {
@@ -236,16 +275,21 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
236275 // this shouldn't happen as we check the memory_type settings during
237276 // the initialization
238277 LOG_ERR ("unsupported USM memory type" );
239- return UMF_RESULT_ERROR_UNKNOWN ;
278+ assert (false) ;
240279 }
241280
242281 // check the alignment
243282 if (alignment > 0 && ((uintptr_t )(* resultPtr ) % alignment ) != 0 ) {
244283 cu_memory_provider_free (provider , * resultPtr , size );
245284 LOG_ERR ("unsupported alignment size" );
285+ set_context (restore_ctx , & restore_ctx );
246286 return UMF_RESULT_ERROR_INVALID_ALIGNMENT ;
247287 }
248288
289+ umf_result = set_context (restore_ctx , & restore_ctx );
290+ if (umf_result != UMF_RESULT_SUCCESS ) {
291+ return umf_result ;
292+ }
249293 return cu2umf_result (cu_result );
250294}
251295
0 commit comments