@@ -14,6 +14,40 @@
 
 #if defined(UMF_NO_CUDA_PROVIDER)
 
+umf_result_t umfCUDAMemoryProviderParamsCreate(
+    umf_cuda_memory_provider_params_handle_t *hParams) {
+    (void)hParams;
+    return UMF_RESULT_ERROR_NOT_SUPPORTED;
+}
+
+umf_result_t umfCUDAMemoryProviderParamsDestroy(
+    umf_cuda_memory_provider_params_handle_t hParams) {
+    (void)hParams;
+    return UMF_RESULT_ERROR_NOT_SUPPORTED;
+}
+
+umf_result_t umfCUDAMemoryProviderParamsSetContext(
+    umf_cuda_memory_provider_params_handle_t hParams, void *hContext) {
+    (void)hParams;
+    (void)hContext;
+    return UMF_RESULT_ERROR_NOT_SUPPORTED;
+}
+
+umf_result_t umfCUDAMemoryProviderParamsSetDevice(
+    umf_cuda_memory_provider_params_handle_t hParams, int hDevice) {
+    (void)hParams;
+    (void)hDevice;
+    return UMF_RESULT_ERROR_NOT_SUPPORTED;
+}
+
+umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
+    umf_cuda_memory_provider_params_handle_t hParams,
+    umf_usm_memory_type_t memoryType) {
+    (void)hParams;
+    (void)memoryType;
+    return UMF_RESULT_ERROR_NOT_SUPPORTED;
+}
+
 umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
     // not supported
     return NULL;
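
Note on the stub path: with UMF_NO_CUDA_PROVIDER defined, every new params entry point returns UMF_RESULT_ERROR_NOT_SUPPORTED, so callers can probe for CUDA support at runtime. A minimal sketch of that check (my illustration, not part of this patch; it assumes the public header umf/providers/provider_cuda.h declares these functions):

    umf_cuda_memory_provider_params_handle_t params = NULL;
    umf_result_t ret = umfCUDAMemoryProviderParamsCreate(&params);
    if (ret == UMF_RESULT_ERROR_NOT_SUPPORTED) {
        // CUDA support was compiled out; fall back to another provider.
    }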
@@ -48,6 +82,13 @@ typedef struct cu_memory_provider_t {
     size_t min_alignment;
 } cu_memory_provider_t;
 
+// CUDA Memory Provider settings struct
+typedef struct umf_cuda_memory_provider_params_t {
+    void *cuda_context_handle;         ///< Handle to the CUDA context
+    int cuda_device_handle;            ///< Handle to the CUDA device
+    umf_usm_memory_type_t memory_type; ///< Allocation memory type
+} umf_cuda_memory_provider_params_t;
+
 typedef struct cu_ops_t {
     CUresult (*cuMemGetAllocationGranularity)(
         size_t *granularity, const CUmemAllocationProp *prop,
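
Since umf_cuda_memory_provider_params_t is defined here in the implementation file, the handle type used by the new entry points is presumably declared as an opaque pointer in the public header, along these lines (an assumption, not shown in this diff):

    typedef struct umf_cuda_memory_provider_params_t
        *umf_cuda_memory_provider_params_handle_t;

Keeping the layout private means fields can be added later without breaking the ABI; callers have to go through the setters below.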
|
@@ -158,14 +199,81 @@ static void init_cu_global_state(void) {
     }
 }
 
+umf_result_t umfCUDAMemoryProviderParamsCreate(
+    umf_cuda_memory_provider_params_handle_t *hParams) {
+    if (!hParams) {
+        LOG_ERR("CUDA Memory Provider params handle is NULL");
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    umf_cuda_memory_provider_params_handle_t params_data =
+        umf_ba_global_alloc(sizeof(umf_cuda_memory_provider_params_t));
+    if (!params_data) {
+        LOG_ERR("Cannot allocate memory for CUDA Memory Provider params");
+        return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
+    }
+
+    params_data->cuda_context_handle = NULL;
+    params_data->cuda_device_handle = -1;
+    params_data->memory_type = UMF_MEMORY_TYPE_UNKNOWN;
+
+    *hParams = params_data;
+
+    return UMF_RESULT_SUCCESS;
+}
+
+umf_result_t umfCUDAMemoryProviderParamsDestroy(
+    umf_cuda_memory_provider_params_handle_t hParams) {
+    umf_ba_global_free(hParams);
+
+    return UMF_RESULT_SUCCESS;
+}
+
+umf_result_t umfCUDAMemoryProviderParamsSetContext(
+    umf_cuda_memory_provider_params_handle_t hParams, void *hContext) {
+    if (!hParams) {
+        LOG_ERR("CUDA Memory Provider params handle is NULL");
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    hParams->cuda_context_handle = hContext;
+
+    return UMF_RESULT_SUCCESS;
+}
+
+umf_result_t umfCUDAMemoryProviderParamsSetDevice(
+    umf_cuda_memory_provider_params_handle_t hParams, int hDevice) {
+    if (!hParams) {
+        LOG_ERR("CUDA Memory Provider params handle is NULL");
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    hParams->cuda_device_handle = hDevice;
+
+    return UMF_RESULT_SUCCESS;
+}
+
+umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
+    umf_cuda_memory_provider_params_handle_t hParams,
+    umf_usm_memory_type_t memoryType) {
+    if (!hParams) {
+        LOG_ERR("CUDA Memory Provider params handle is NULL");
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    hParams->memory_type = memoryType;
+
+    return UMF_RESULT_SUCCESS;
+}
+
 static umf_result_t cu_memory_provider_initialize(void *params,
                                                   void **provider) {
     if (params == NULL) {
         return UMF_RESULT_ERROR_INVALID_ARGUMENT;
     }
 
-    cuda_memory_provider_params_t *cu_params =
-        (cuda_memory_provider_params_t *)params;
+    umf_cuda_memory_provider_params_handle_t cu_params =
+        (umf_cuda_memory_provider_params_handle_t)params;
 
     if (cu_params->memory_type == UMF_MEMORY_TYPE_UNKNOWN ||
         cu_params->memory_type > UMF_MEMORY_TYPE_SHARED) {