@@ -374,20 +374,6 @@ static umf_result_t cu_memory_provider_initialize(const void *params,
374374 snprintf (cu_provider -> name , sizeof (cu_provider -> name ), "%s" ,
375375 cu_params -> name );
376376
377- // CUDA alloc functions doesn't allow to provide user alignment - get the
378- // minimum one from the driver
379- size_t min_alignment = 0 ;
380- CUmemAllocationProp allocProps = {0 };
381- allocProps .location .type = CU_MEM_LOCATION_TYPE_DEVICE ;
382- allocProps .type = CU_MEM_ALLOCATION_TYPE_PINNED ;
383- allocProps .location .id = cu_params -> cuda_device_handle ;
384- CUresult cu_result = g_cu_ops .cuMemGetAllocationGranularity (
385- & min_alignment , & allocProps , CU_MEM_ALLOC_GRANULARITY_MINIMUM );
386- if (cu_result != CUDA_SUCCESS ) {
387- umf_ba_global_free (cu_provider );
388- return cu2umf_result (cu_result );
389- }
390-
391377 cu_provider -> context = cu_params -> cuda_context_handle ;
392378 cu_provider -> device = cu_params -> cuda_device_handle ;
393379 cu_provider -> memory_type = cu_params -> memory_type ;
@@ -396,14 +382,28 @@ static umf_result_t cu_memory_provider_initialize(const void *params,
396382 // If the memory type is shared (CUDA managed), the allocation flags must
397383 // be set. NOTE: we do not check here if the flags are valid -
398384 // this will be done by CUDA runtime.
399- if (cu_params -> memory_type == UMF_MEMORY_TYPE_SHARED &&
385+ if (cu_provider -> memory_type == UMF_MEMORY_TYPE_SHARED &&
400386 cu_params -> alloc_flags == 0 ) {
401387 // the default setting is CU_MEM_ATTACH_GLOBAL
402388 cu_provider -> alloc_flags = CU_MEM_ATTACH_GLOBAL ;
403389 } else {
404390 cu_provider -> alloc_flags = cu_params -> alloc_flags ;
405391 }
406392
393+ // CUDA alloc functions don't allow the user to specify an alignment - get
394+ // the minimum one from the driver
395+ size_t min_alignment = 0 ;
396+ CUmemAllocationProp allocProps = {0 };
397+ allocProps .location .type = CU_MEM_LOCATION_TYPE_DEVICE ;
398+ allocProps .type = CU_MEM_ALLOCATION_TYPE_PINNED ;
399+ allocProps .location .id = cu_params -> cuda_device_handle ;
400+ CUresult cu_result = g_cu_ops .cuMemGetAllocationGranularity (
401+ & min_alignment , & allocProps , CU_MEM_ALLOC_GRANULARITY_MINIMUM );
402+ if (cu_result != CUDA_SUCCESS ) {
403+ umf_ba_global_free (cu_provider );
404+ return cu2umf_result (cu_result );
405+ }
406+
407407 * provider = cu_provider ;
408408
409409 return UMF_RESULT_SUCCESS ;
0 commit comments