@@ -374,20 +374,6 @@ static umf_result_t cu_memory_provider_initialize(const void *params,
374
374
snprintf (cu_provider -> name , sizeof (cu_provider -> name ), "%s" ,
375
375
cu_params -> name );
376
376
377
- // CUDA alloc functions doesn't allow to provide user alignment - get the
378
- // minimum one from the driver
379
- size_t min_alignment = 0 ;
380
- CUmemAllocationProp allocProps = {0 };
381
- allocProps .location .type = CU_MEM_LOCATION_TYPE_DEVICE ;
382
- allocProps .type = CU_MEM_ALLOCATION_TYPE_PINNED ;
383
- allocProps .location .id = cu_params -> cuda_device_handle ;
384
- CUresult cu_result = g_cu_ops .cuMemGetAllocationGranularity (
385
- & min_alignment , & allocProps , CU_MEM_ALLOC_GRANULARITY_MINIMUM );
386
- if (cu_result != CUDA_SUCCESS ) {
387
- umf_ba_global_free (cu_provider );
388
- return cu2umf_result (cu_result );
389
- }
390
-
391
377
cu_provider -> context = cu_params -> cuda_context_handle ;
392
378
cu_provider -> device = cu_params -> cuda_device_handle ;
393
379
cu_provider -> memory_type = cu_params -> memory_type ;
@@ -396,14 +382,28 @@ static umf_result_t cu_memory_provider_initialize(const void *params,
396
382
// If the memory type is shared (CUDA managed), the allocation flags must
397
383
// be set. NOTE: we do not check here if the flags are valid -
398
384
// this will be done by CUDA runtime.
399
- if (cu_params -> memory_type == UMF_MEMORY_TYPE_SHARED &&
385
+ if (cu_provider -> memory_type == UMF_MEMORY_TYPE_SHARED &&
400
386
cu_params -> alloc_flags == 0 ) {
401
387
// the default setting is CU_MEM_ATTACH_GLOBAL
402
388
cu_provider -> alloc_flags = CU_MEM_ATTACH_GLOBAL ;
403
389
} else {
404
390
cu_provider -> alloc_flags = cu_params -> alloc_flags ;
405
391
}
406
392
393
+ // CUDA alloc functions don't allow specifying a user alignment - get the
394
+ // minimum one from the driver
395
+ size_t min_alignment = 0 ;
396
+ CUmemAllocationProp allocProps = {0 };
397
+ allocProps .location .type = CU_MEM_LOCATION_TYPE_DEVICE ;
398
+ allocProps .type = CU_MEM_ALLOCATION_TYPE_PINNED ;
399
+ allocProps .location .id = cu_params -> cuda_device_handle ;
400
+ CUresult cu_result = g_cu_ops .cuMemGetAllocationGranularity (
401
+ & min_alignment , & allocProps , CU_MEM_ALLOC_GRANULARITY_MINIMUM );
402
+ if (cu_result != CUDA_SUCCESS ) {
403
+ umf_ba_global_free (cu_provider );
404
+ return cu2umf_result (cu_result );
405
+ }
406
+
407
407
* provider = cu_provider ;
408
408
409
409
return UMF_RESULT_SUCCESS ;
0 commit comments