Skip to content

Commit 8dcfa8a

Browse files
committed
Split continue
1 parent 947cd91 commit 8dcfa8a

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

src/provider/provider_cuda.c

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -377,45 +377,46 @@ static umf_result_t cu_memory_provider_initialize(const void *params,
377377
cu_provider->context = cu_params->cuda_context_handle;
378378
cu_provider->device = cu_params->cuda_device_handle;
379379
cu_provider->memory_type = cu_params->memory_type;
380-
cu_provider->min_alignment = min_alignment;
380+
cu_provider->alloc_flags = cu_params->alloc_flags;
381381

382-
// If the memory type is shared (CUDA managed), the allocation flags must
383-
// be set. NOTE: we do not check here if the flags are valid -
384-
// this will be done by CUDA runtime.
385-
if (cu_provider->memory_type == UMF_MEMORY_TYPE_SHARED &&
386-
cu_params->alloc_flags == 0) {
387-
// the default setting is CU_MEM_ATTACH_GLOBAL
388-
cu_provider->alloc_flags = CU_MEM_ATTACH_GLOBAL;
389-
} else {
390-
cu_provider->alloc_flags = cu_params->alloc_flags;
391-
}
382+
*provider = cu_provider;
383+
384+
return UMF_RESULT_SUCCESS;
385+
}
386+
387+
static umf_result_t cu_memory_provider_finalize(void *provider) {
388+
umf_ba_global_free(provider);
389+
return UMF_RESULT_SUCCESS;
390+
}
391+
392+
static umf_result_t cu_memory_provider_post_initialize(void *provider) {
393+
cu_memory_provider_t *cu_provider = (cu_memory_provider_t *)provider;
392394

393395
// CUDA alloc functions doesn't allow to provide user alignment - get the
394396
// minimum one from the driver
395397
size_t min_alignment = 0;
396398
CUmemAllocationProp allocProps = {0};
397399
allocProps.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
398400
allocProps.type = CU_MEM_ALLOCATION_TYPE_PINNED;
399-
allocProps.location.id = cu_params->cuda_device_handle;
401+
allocProps.location.id = cu_provider->device;
400402
CUresult cu_result = g_cu_ops.cuMemGetAllocationGranularity(
401403
&min_alignment, &allocProps, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
402404
if (cu_result != CUDA_SUCCESS) {
403405
umf_ba_global_free(cu_provider);
404406
return cu2umf_result(cu_result);
405407
}
406408

407-
*provider = cu_provider;
408-
409-
return UMF_RESULT_SUCCESS;
410-
}
411409

412-
static umf_result_t cu_memory_provider_finalize(void *provider) {
413-
umf_ba_global_free(provider);
414-
return UMF_RESULT_SUCCESS;
415-
}
410+
cu_provider->min_alignment = min_alignment;
416411

417-
static umf_result_t cu_memory_provider_post_initialize(void *provider) {
418-
(void)provider;
412+
// If the memory type is shared (CUDA managed), the allocation flags must
413+
// be set. NOTE: we do not check here if the flags are valid -
414+
// this will be done by CUDA runtime.
415+
if (cu_provider->memory_type == UMF_MEMORY_TYPE_SHARED &&
416+
cu_provider->alloc_flags == 0) {
417+
// the default setting is CU_MEM_ATTACH_GLOBAL
418+
cu_provider->alloc_flags = CU_MEM_ATTACH_GLOBAL;
419+
}
419420
// For initial version, just return success
420421
return UMF_RESULT_SUCCESS;
421422
}

0 commit comments

Comments
 (0)