@@ -377,18 +377,20 @@ static umf_result_t cu_memory_provider_initialize(const void *params,
377377 cu_provider -> context = cu_params -> cuda_context_handle ;
378378 cu_provider -> device = cu_params -> cuda_device_handle ;
379379 cu_provider -> memory_type = cu_params -> memory_type ;
380- cu_provider -> min_alignment = min_alignment ;
380+ cu_provider -> alloc_flags = cu_params -> alloc_flags ;
381381
382- // If the memory type is shared (CUDA managed), the allocation flags must
383- // be set. NOTE: we do not check here if the flags are valid -
384- // this will be done by CUDA runtime.
385- if (cu_provider -> memory_type == UMF_MEMORY_TYPE_SHARED &&
386- cu_params -> alloc_flags == 0 ) {
387- // the default setting is CU_MEM_ATTACH_GLOBAL
388- cu_provider -> alloc_flags = CU_MEM_ATTACH_GLOBAL ;
389- } else {
390- cu_provider -> alloc_flags = cu_params -> alloc_flags ;
391- }
382+ * provider = cu_provider ;
383+
384+ return UMF_RESULT_SUCCESS ;
385+ }
386+
387+ static umf_result_t cu_memory_provider_finalize (void * provider ) {
388+ umf_ba_global_free (provider );
389+ return UMF_RESULT_SUCCESS ;
390+ }
391+
392+ static umf_result_t cu_memory_provider_post_initialize (void * provider ) {
393+ cu_memory_provider_t * cu_provider = (cu_memory_provider_t * )provider ;
392394
393395 // CUDA alloc functions doesn't allow to provide user alignment - get the
394396 // minimum one from the driver
@@ -404,18 +406,17 @@ static umf_result_t cu_memory_provider_initialize(const void *params,
404406 return cu2umf_result (cu_result );
405407 }
406408
407- * provider = cu_provider ;
408-
409- return UMF_RESULT_SUCCESS ;
410- }
411409
412- static umf_result_t cu_memory_provider_finalize (void * provider ) {
413- umf_ba_global_free (provider );
414- return UMF_RESULT_SUCCESS ;
415- }
410+ cu_provider -> min_alignment = min_alignment ;
416411
417- static umf_result_t cu_memory_provider_post_initialize (void * provider ) {
418- (void )provider ;
412+ // If the memory type is shared (CUDA managed), the allocation flags must
413+ // be set. NOTE: we do not check here if the flags are valid -
414+ // this will be done by CUDA runtime.
415+ if (cu_provider -> memory_type == UMF_MEMORY_TYPE_SHARED &&
416+ cu_provider -> alloc_flags == 0 ) {
417+ // the default setting is CU_MEM_ATTACH_GLOBAL
418+ cu_provider -> alloc_flags = CU_MEM_ATTACH_GLOBAL ;
419+ }
419420 // For initial version, just return success
420421 return UMF_RESULT_SUCCESS ;
421422}
0 commit comments