Skip to content

Commit 5c7f175

Browse files
committed
Updated CUDA provider tests
1 parent d4a9701 commit 5c7f175

File tree

2 files changed

+29
-46
lines changed

2 files changed

+29
-46
lines changed

test/providers/cuda_helpers.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct libcu_ops {
2626
CUresult (*cuMemFreeHost)(void *p);
2727
CUresult (*cuMemsetD32)(CUdeviceptr dstDevice, unsigned int pattern,
2828
size_t size);
29-
CUresult (*cuMemcpyDtoH)(void *dstHost, CUdeviceptr srcDevice, size_t size);
29+
CUresult (*cuMemcpy)(CUdeviceptr dst, CUdeviceptr src, size_t size);
3030
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
3131
CUpointer_attribute *attributes,
3232
void **data, CUdeviceptr ptr);
@@ -116,10 +116,10 @@ int InitCUDAOps() {
116116
fprintf(stderr, "cuMemsetD32_v2 symbol not found in %s\n", lib_name);
117117
return -1;
118118
}
119-
*(void **)&libcu_ops.cuMemcpyDtoH =
120-
utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpyDtoH_v2", lib_name);
121-
if (libcu_ops.cuMemcpyDtoH == nullptr) {
122-
fprintf(stderr, "cuMemcpyDtoH_v2 symbol not found in %s\n", lib_name);
119+
*(void **)&libcu_ops.cuMemcpy =
120+
utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpy", lib_name);
121+
if (libcu_ops.cuMemcpy == nullptr) {
122+
fprintf(stderr, "cuMemcpy symbol not found in %s\n", lib_name);
123123
return -1;
124124
}
125125
*(void **)&libcu_ops.cuPointerGetAttributes = utils_get_symbol_addr(
@@ -147,7 +147,7 @@ int InitCUDAOps() {
147147
libcu_ops.cuMemFree = cuMemFree;
148148
libcu_ops.cuMemFreeHost = cuMemFreeHost;
149149
libcu_ops.cuMemsetD32 = cuMemsetD32;
150-
libcu_ops.cuMemcpyDtoH = cuMemcpyDtoH;
150+
libcu_ops.cuMemcpy = cuMemcpy;
151151
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
152152

153153
return 0;
@@ -193,9 +193,10 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
193193
(void)device;
194194

195195
int ret = 0;
196-
CUresult res = libcu_ops.cuMemcpyDtoH(dst_ptr, (CUdeviceptr)src_ptr, size);
196+
CUresult res =
197+
libcu_ops.cuMemcpy((CUdeviceptr)dst_ptr, (CUdeviceptr)src_ptr, size);
197198
if (res != CUDA_SUCCESS) {
198-
fprintf(stderr, "cuMemcpyDtoH() failed!\n");
199+
fprintf(stderr, "cuMemcpy() failed!\n");
199200
return -1;
200201
}
201202

test/providers/provider_cuda.cpp

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@ using namespace umf_test;
2121

2222
class CUDAMemoryAccessor : public MemoryAccessor {
2323
public:
24-
void init(CUcontext hContext, CUdevice hDevice) {
25-
hDevice_ = hDevice;
26-
hContext_ = hContext;
27-
}
24+
CUDAMemoryAccessor(CUcontext hContext, CUdevice hDevice)
25+
: hDevice_(hDevice), hContext_(hContext) {}
2826

2927
void fill(void *ptr, size_t size, const void *pattern,
3028
size_t pattern_size) {
@@ -53,7 +51,7 @@ class CUDAMemoryAccessor : public MemoryAccessor {
5351
};
5452

5553
using CUDAProviderTestParams =
56-
std::tuple<umf_usm_memory_type_t, MemoryAccessor *>;
54+
std::tuple<cuda_memory_provider_params_t, MemoryAccessor *>;
5755

5856
struct umfCUDAProviderTest
5957
: umf_test::test,
@@ -62,23 +60,12 @@ struct umfCUDAProviderTest
6260
void SetUp() override {
6361
test::SetUp();
6462

65-
auto [memory_type, accessor] = this->GetParam();
66-
params = create_cuda_prov_params(memory_type);
63+
auto [cuda_params, accessor] = this->GetParam();
64+
params = cuda_params;
6765
memAccessor = accessor;
68-
if (memory_type == UMF_MEMORY_TYPE_DEVICE) {
69-
((CUDAMemoryAccessor *)memAccessor)
70-
->init((CUcontext)params.cuda_context_handle,
71-
params.cuda_device_handle);
72-
}
7366
}
7467

75-
void TearDown() override {
76-
if (params.cuda_context_handle) {
77-
int ret = destroy_context((CUcontext)params.cuda_context_handle);
78-
ASSERT_EQ(ret, 0);
79-
}
80-
test::TearDown();
81-
}
68+
void TearDown() override { test::TearDown(); }
8269

8370
cuda_memory_provider_params_t params;
8471
MemoryAccessor *memAccessor = nullptr;
@@ -150,33 +137,28 @@ TEST_P(umfCUDAProviderTest, allocInvalidSize) {
150137
umf_result = umfMemoryProviderAlloc(provider, 0, 0, &ptr);
151138
ASSERT_EQ(umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
152139
}
153-
154-
// destroy context and try to alloc some memory
155-
destroy_context((CUcontext)params.cuda_context_handle);
156-
params.cuda_context_handle = 0;
157-
umf_result = umfMemoryProviderAlloc(provider, 128, 0, &ptr);
158-
ASSERT_EQ(umf_result, UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC);
159-
160-
const char *message;
161-
int32_t error;
162-
umfMemoryProviderGetLastNativeError(provider, &message, &error);
163-
ASSERT_EQ(error, CUDA_ERROR_INVALID_CONTEXT);
164-
const char *expected_message =
165-
"CUDA_ERROR_INVALID_CONTEXT - invalid device context";
166-
ASSERT_EQ(strncmp(message, expected_message, strlen(expected_message)), 0);
167140
}
168141

169142
// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
170143

171-
CUDAMemoryAccessor cuAccessor;
144+
cuda_memory_provider_params_t cuParams_device_memory =
145+
create_cuda_prov_params(UMF_MEMORY_TYPE_DEVICE);
146+
cuda_memory_provider_params_t cuParams_shared_memory =
147+
create_cuda_prov_params(UMF_MEMORY_TYPE_SHARED);
148+
cuda_memory_provider_params_t cuParams_host_memory =
149+
create_cuda_prov_params(UMF_MEMORY_TYPE_HOST);
150+
151+
CUDAMemoryAccessor
152+
cuAccessor((CUcontext)cuParams_device_memory.cuda_context_handle,
153+
(CUdevice)cuParams_device_memory.cuda_device_handle);
172154
HostMemoryAccessor hostAccessor;
173155

174156
INSTANTIATE_TEST_SUITE_P(
175157
umfCUDAProviderTestSuite, umfCUDAProviderTest,
176158
::testing::Values(
177-
CUDAProviderTestParams{UMF_MEMORY_TYPE_DEVICE, &cuAccessor},
178-
CUDAProviderTestParams{UMF_MEMORY_TYPE_SHARED, &hostAccessor},
179-
CUDAProviderTestParams{UMF_MEMORY_TYPE_HOST, &hostAccessor}));
159+
CUDAProviderTestParams{cuParams_device_memory, &cuAccessor},
160+
CUDAProviderTestParams{cuParams_shared_memory, &hostAccessor},
161+
CUDAProviderTestParams{cuParams_host_memory, &hostAccessor}));
180162

181163
// TODO: add IPC API
182164
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(umfIpcTest);
@@ -185,5 +167,5 @@ INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfIpcTest,
185167
::testing::Values(ipcTestParams{
186168
umfProxyPoolOps(), nullptr,
187169
umfCUDAMemoryProviderOps(),
188-
&cuParams_device_memory, &l0Accessor}));
170+
&cuParams_device_memory, &cuAccessor}));
189171
*/

0 commit comments

Comments
 (0)