Skip to content

Commit 47414fd

Browse files
committed
CUDA provider split
1 parent 05c77a9 commit 47414fd

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

src/provider/provider_cuda.c

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -374,20 +374,6 @@ static umf_result_t cu_memory_provider_initialize(const void *params,
374374
snprintf(cu_provider->name, sizeof(cu_provider->name), "%s",
375375
cu_params->name);
376376

377-
// CUDA alloc functions don't allow providing a user alignment - get the
378-
// minimum one from the driver
379-
size_t min_alignment = 0;
380-
CUmemAllocationProp allocProps = {0};
381-
allocProps.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
382-
allocProps.type = CU_MEM_ALLOCATION_TYPE_PINNED;
383-
allocProps.location.id = cu_params->cuda_device_handle;
384-
CUresult cu_result = g_cu_ops.cuMemGetAllocationGranularity(
385-
&min_alignment, &allocProps, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
386-
if (cu_result != CUDA_SUCCESS) {
387-
umf_ba_global_free(cu_provider);
388-
return cu2umf_result(cu_result);
389-
}
390-
391377
cu_provider->context = cu_params->cuda_context_handle;
392378
cu_provider->device = cu_params->cuda_device_handle;
393379
cu_provider->memory_type = cu_params->memory_type;
@@ -396,14 +382,28 @@ static umf_result_t cu_memory_provider_initialize(const void *params,
396382
// If the memory type is shared (CUDA managed), the allocation flags must
397383
// be set. NOTE: we do not check here if the flags are valid -
398384
// this will be done by CUDA runtime.
399-
if (cu_params->memory_type == UMF_MEMORY_TYPE_SHARED &&
385+
if (cu_provider->memory_type == UMF_MEMORY_TYPE_SHARED &&
400386
cu_params->alloc_flags == 0) {
401387
// the default setting is CU_MEM_ATTACH_GLOBAL
402388
cu_provider->alloc_flags = CU_MEM_ATTACH_GLOBAL;
403389
} else {
404390
cu_provider->alloc_flags = cu_params->alloc_flags;
405391
}
406392

393+
// CUDA alloc functions don't allow providing a user alignment - get the
394+
// minimum one from the driver
395+
size_t min_alignment = 0;
396+
CUmemAllocationProp allocProps = {0};
397+
allocProps.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
398+
allocProps.type = CU_MEM_ALLOCATION_TYPE_PINNED;
399+
allocProps.location.id = cu_params->cuda_device_handle;
400+
CUresult cu_result = g_cu_ops.cuMemGetAllocationGranularity(
401+
&min_alignment, &allocProps, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
402+
if (cu_result != CUDA_SUCCESS) {
403+
umf_ba_global_free(cu_provider);
404+
return cu2umf_result(cu_result);
405+
}
406+
407407
*provider = cu_provider;
408408

409409
return UMF_RESULT_SUCCESS;

0 commit comments

Comments
 (0)