@@ -51,6 +51,8 @@ typedef struct cu_ops_t {
5151
5252 CUresult (* cuGetErrorName )(CUresult error , const char * * pStr );
5353 CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
54+ CUresult (* cuCtxGetCurrent )(CUcontext * pctx );
55+ CUresult (* cuCtxSetCurrent )(CUcontext ctx );
5456} cu_ops_t ;
5557
5658static cu_ops_t g_cu_ops ;
@@ -117,11 +119,16 @@ static void init_cu_global_state(void) {
117119 utils_get_symbol_addr (0 , "cuGetErrorName" , lib_name );
118120 * (void * * )& g_cu_ops .cuGetErrorString =
119121 utils_get_symbol_addr (0 , "cuGetErrorString" , lib_name );
122+ * (void * * )& g_cu_ops .cuCtxGetCurrent =
123+ utils_get_symbol_addr (0 , "cuCtxGetCurrent" , lib_name );
124+ * (void * * )& g_cu_ops .cuCtxSetCurrent =
125+ utils_get_symbol_addr (0 , "cuCtxSetCurrent" , lib_name );
120126
121127 if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
122128 !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
123129 !g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
124- !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ) {
130+ !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
131+ !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ) {
125132 LOG_ERR ("Required CUDA symbols not found." );
126133 Init_cu_global_state_failed = true;
127134 }
@@ -190,6 +197,31 @@ static void cu_memory_provider_finalize(void *provider) {
190197 umf_ba_global_free (provider );
191198}
192199
200+ /*
201+ * This function is used by the CUDA provider to make sure that
202+ * the required context is set. If the current context is
203+ * not the required one, it will be saved in restore_ctx.
204+ */
205+ static inline umf_result_t set_context (CUcontext required_ctx ,
206+ CUcontext * restore_ctx ) {
207+ CUcontext current_ctx = NULL ;
208+ CUresult cu_result = g_cu_ops .cuCtxGetCurrent (& current_ctx );
209+ if (cu_result != CUDA_SUCCESS ) {
210+ LOG_ERR ("cuCtxGetCurrent() failed." );
211+ return cu2umf_result (cu_result );
212+ }
213+ * restore_ctx = current_ctx ;
214+ if (current_ctx != required_ctx ) {
215+ cu_result = g_cu_ops .cuCtxSetCurrent (required_ctx );
216+ if (cu_result != CUDA_SUCCESS ) {
217+ LOG_ERR ("cuCtxSetCurrent() failed." );
218+ return cu2umf_result (cu_result );
219+ }
220+ }
221+
222+ return UMF_RESULT_SUCCESS ;
223+ }
224+
193225static umf_result_t cu_memory_provider_alloc (void * provider , size_t size ,
194226 size_t alignment ,
195227 void * * resultPtr ) {
@@ -205,6 +237,13 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
205237 return UMF_RESULT_ERROR_NOT_SUPPORTED ;
206238 }
207239
240+ // Remember current context and set the one from the provider
241+ CUcontext restore_ctx = NULL ;
242+ umf_result_t umf_result = set_context (cu_provider -> context , & restore_ctx );
243+ if (umf_result != UMF_RESULT_SUCCESS ) {
244+ return umf_result ;
245+ }
246+
208247 CUresult cu_result = CUDA_SUCCESS ;
209248 switch (cu_provider -> memory_type ) {
210249 case UMF_MEMORY_TYPE_HOST : {
@@ -224,16 +263,21 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
224263 // this shouldn't happen as we check the memory_type settings during
225264 // the initialization
226265 LOG_ERR ("unsupported USM memory type" );
227- return UMF_RESULT_ERROR_UNKNOWN ;
266+ assert (false) ;
228267 }
229268
230269 // check the alignment
231270 if (alignment > 0 && ((uintptr_t )(* resultPtr ) % alignment ) != 0 ) {
232271 cu_memory_provider_free (provider , * resultPtr , size );
233272 LOG_ERR ("unsupported alignment size" );
273+ set_context (restore_ctx , & restore_ctx );
234274 return UMF_RESULT_ERROR_INVALID_ALIGNMENT ;
235275 }
236276
277+ umf_result = set_context (restore_ctx , & restore_ctx );
278+ if (umf_result != UMF_RESULT_SUCCESS ) {
279+ return umf_result ;
280+ }
237281 return cu2umf_result (cu_result );
238282}
239283
0 commit comments