Skip to content

Commit b2b2af0

Browse files
committed
use current ctx and dev by default in CUDA prov
1 parent 7727ad1 commit b2b2af0

File tree

5 files changed

+73
-3
lines changed

5 files changed

+73
-3
lines changed

include/umf/providers/provider_cuda.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,17 @@ umf_result_t umfCUDAMemoryProviderParamsDestroy(
3434
/// @brief Set the CUDA context handle in the parameters struct.
3535
/// @param hParams handle to the parameters of the CUDA Memory Provider.
3636
/// @param hContext handle to the CUDA context.
37+
/// @note This function is optional - if no context is set in the params
38+
/// struct, the current context will be used.
3739
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
3840
umf_result_t umfCUDAMemoryProviderParamsSetContext(
3941
umf_cuda_memory_provider_params_handle_t hParams, void *hContext);
4042

4143
/// @brief Set the CUDA device handle in the parameters struct.
4244
/// @param hParams handle to the parameters of the CUDA Memory Provider.
4345
/// @param hDevice handle to the CUDA device.
46+
/// @note This function is optional - if no device is set in the params
47+
/// struct, the current context's device will be used.
4448
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
4549
umf_result_t umfCUDAMemoryProviderParamsSetDevice(
4650
umf_cuda_memory_provider_params_handle_t hParams, int hDevice);

src/provider/provider_cuda.c

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,23 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
260260
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
261261
}
262262

263-
params_data->cuda_context_handle = NULL;
264-
params_data->cuda_device_handle = -1;
263+
// initialize context and device to the current ones
264+
CUcontext current_ctx = NULL;
265+
CUresult cu_result = g_cu_ops.cuCtxGetCurrent(&current_ctx);
266+
if (cu_result == CUDA_SUCCESS) {
267+
params_data->cuda_context_handle = current_ctx;
268+
} else {
269+
params_data->cuda_context_handle = NULL;
270+
}
271+
272+
CUdevice current_device = -1;
273+
cu_result = cuCtxGetDevice(&current_device);
274+
if (cu_result == CUDA_SUCCESS) {
275+
params_data->cuda_device_handle = current_device;
276+
} else {
277+
params_data->cuda_device_handle = -1;
278+
}
279+
265280
params_data->memory_type = UMF_MEMORY_TYPE_UNKNOWN;
266281
params_data->alloc_flags = 0;
267282

test/providers/cuda_helpers.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,18 @@ CUcontext get_mem_context(void *ptr) {
412412
return context;
413413
}
414414

415+
int get_mem_device(void *ptr) {
416+
int device;
417+
CUresult res = libcu_ops.cuPointerGetAttribute(
418+
&device, CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, (CUdeviceptr)ptr);
419+
if (res != CUDA_SUCCESS) {
420+
fprintf(stderr, "cuPointerGetAttribute() failed!\n");
421+
return -1;
422+
}
423+
424+
return device;
425+
}
426+
415427
CUcontext get_current_context() {
416428
CUcontext context;
417429
CUresult res = libcu_ops.cuCtxGetCurrent(&context);

test/providers/cuda_helpers.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ unsigned int get_mem_host_alloc_flags(void *ptr);
4848

4949
CUcontext get_mem_context(void *ptr);
5050

51+
int get_mem_device(void *ptr);
52+
5153
CUcontext get_current_context();
5254

5355
#ifdef __cplusplus

test/providers/provider_cuda.cpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,15 @@ struct umfCUDAProviderTest
142142

143143
memAccessor = nullptr;
144144
expected_context = cudaTestHelper.get_test_context();
145+
expected_device = cudaTestHelper.get_test_device();
145146
params = create_cuda_prov_params(cudaTestHelper.get_test_context(),
146147
cudaTestHelper.get_test_device(),
147148
memory_type, 0 /* alloc flags */);
148149
ASSERT_NE(expected_context, nullptr);
150+
ASSERT_GE(expected_device, 0);
149151

150152
switch (memory_type) {
151153
case UMF_MEMORY_TYPE_DEVICE:
152-
153154
memAccessor = std::make_unique<CUDAMemoryAccessor>(
154155
cudaTestHelper.get_test_context(),
155156
cudaTestHelper.get_test_device());
@@ -178,6 +179,7 @@ struct umfCUDAProviderTest
178179

179180
std::unique_ptr<MemoryAccessor> memAccessor = nullptr;
180181
CUcontext expected_context = nullptr;
182+
int expected_device = -1;
181183
umf_usm_memory_type_t expected_memory_type;
182184
};
183185

@@ -328,6 +330,41 @@ TEST_P(umfCUDAProviderTest, getPageSizeInvalidArgs) {
328330
umfMemoryProviderDestroy(provider);
329331
}
330332

333+
TEST_P(umfCUDAProviderTest, cudaProviderDefaultParams) {
334+
umf_cuda_memory_provider_params_handle_t defaultParams = nullptr;
335+
umf_result_t umf_result = umfCUDAMemoryProviderParamsCreate(&params);
336+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
337+
338+
umf_result = umfCUDAMemoryProviderParamsSetMemoryType(defaultParams,
339+
expected_memory_type);
340+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
341+
342+
// NOTE: we intentionally do not set any context and device params
343+
344+
umf_memory_provider_handle_t provider = nullptr;
345+
umf_result = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(),
346+
defaultParams, &provider);
347+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
348+
ASSERT_NE(provider, nullptr);
349+
350+
// do single alloc and check if the context and device id of allocated
351+
// memory are correct
352+
353+
void *ptr = nullptr;
354+
umf_result = umfMemoryProviderAlloc(provider, 128, 0, &ptr);
355+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
356+
ASSERT_NE(ptr, nullptr);
357+
358+
CUcontext actual_mem_context = get_mem_context(ptr);
359+
ASSERT_EQ(actual_mem_context, expected_context);
360+
361+
int actual_device = get_mem_device(ptr);
362+
ASSERT_EQ(actual_device, expected_device);
363+
364+
umfMemoryProviderDestroy(provider);
365+
umfCUDAMemoryProviderParamsDestroy(defaultParams);
366+
}
367+
331368
TEST_P(umfCUDAProviderTest, cudaProviderNullParams) {
332369
umf_result_t res = umfCUDAMemoryProviderParamsCreate(nullptr);
333370
EXPECT_EQ(res, UMF_RESULT_ERROR_INVALID_ARGUMENT);

0 commit comments

Comments
 (0)