Skip to content

Commit c599bb8

Browse files
committed
Updated CUDA provider tests
1 parent f3a5cab commit c599bb8

File tree

2 files changed

+22
-43
lines changed

2 files changed

+22
-43
lines changed

test/providers/cuda_helpers.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct libcu_ops {
2626
CUresult (*cuMemFreeHost)(void *p);
2727
CUresult (*cuMemsetD32)(CUdeviceptr dstDevice, unsigned int pattern,
2828
size_t size);
29-
CUresult (*cuMemcpyDtoH)(void *dstHost, CUdeviceptr srcDevice, size_t size);
29+
CUresult (*cuMemcpy)( CUdeviceptr dst, CUdeviceptr src, size_t size);
3030
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
3131
CUpointer_attribute *attributes,
3232
void **data, CUdeviceptr ptr);
@@ -116,10 +116,10 @@ int InitCUDAOps() {
116116
fprintf(stderr, "cuMemsetD32_v2 symbol not found in %s\n", lib_name);
117117
return -1;
118118
}
119-
*(void **)&libcu_ops.cuMemcpyDtoH =
120-
utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpyDtoH_v2", lib_name);
121-
if (libcu_ops.cuMemcpyDtoH == nullptr) {
122-
fprintf(stderr, "cuMemcpyDtoH_v2 symbol not found in %s\n", lib_name);
119+
*(void **)&libcu_ops.cuMemcpy =
120+
utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpy", lib_name);
121+
if (libcu_ops.cuMemcpy == nullptr) {
122+
fprintf(stderr, "cuMemcpy symbol not found in %s\n", lib_name);
123123
return -1;
124124
}
125125
*(void **)&libcu_ops.cuPointerGetAttributes = utils_get_symbol_addr(
@@ -147,7 +147,7 @@ int InitCUDAOps() {
147147
libcu_ops.cuMemFree = cuMemFree;
148148
libcu_ops.cuMemFreeHost = cuMemFreeHost;
149149
libcu_ops.cuMemsetD32 = cuMemsetD32;
150-
libcu_ops.cuMemcpyDtoH = cuMemcpyDtoH;
150+
libcu_ops.cuMemcpy = cuMemcpy;
151151
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
152152

153153
return 0;
@@ -193,9 +193,9 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
193193
(void)device;
194194

195195
int ret = 0;
196-
CUresult res = libcu_ops.cuMemcpyDtoH(dst_ptr, (CUdeviceptr)src_ptr, size);
196+
CUresult res = libcu_ops.cuMemcpy((CUdeviceptr)dst_ptr, (CUdeviceptr)src_ptr, size);
197197
if (res != CUDA_SUCCESS) {
198-
fprintf(stderr, "cuMemcpyDtoH() failed!\n");
198+
fprintf(stderr, "cuMemcpy() failed!\n");
199199
return -1;
200200
}
201201

test/providers/provider_cuda.cpp

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@ using namespace umf_test;
2121

2222
class CUDAMemoryAccessor : public MemoryAccessor {
2323
public:
24-
void init(CUcontext hContext, CUdevice hDevice) {
25-
hDevice_ = hDevice;
26-
hContext_ = hContext;
27-
}
24+
CUDAMemoryAccessor(CUcontext hContext, CUdevice hDevice): hDevice_(hDevice), hContext_(hContext) {}
2825

2926
void fill(void *ptr, size_t size, const void *pattern,
3027
size_t pattern_size) {
@@ -53,7 +50,7 @@ class CUDAMemoryAccessor : public MemoryAccessor {
5350
};
5451

5552
using CUDAProviderTestParams =
56-
std::tuple<umf_usm_memory_type_t, MemoryAccessor *>;
53+
std::tuple<cuda_memory_provider_params_t, MemoryAccessor *>;
5754

5855
struct umfCUDAProviderTest
5956
: umf_test::test,
@@ -62,21 +59,12 @@ struct umfCUDAProviderTest
6259
void SetUp() override {
6360
test::SetUp();
6461

65-
auto [memory_type, accessor] = this->GetParam();
66-
params = create_cuda_prov_params(memory_type);
62+
auto [cuda_params, accessor] = this->GetParam();
63+
params = cuda_params;
6764
memAccessor = accessor;
68-
if (memory_type == UMF_MEMORY_TYPE_DEVICE) {
69-
((CUDAMemoryAccessor *)memAccessor)
70-
->init((CUcontext)params.cuda_context_handle,
71-
params.cuda_device_handle);
72-
}
7365
}
7466

7567
void TearDown() override {
76-
if (params.cuda_context_handle) {
77-
int ret = destroy_context((CUcontext)params.cuda_context_handle);
78-
ASSERT_EQ(ret, 0);
79-
}
8068
test::TearDown();
8169
}
8270

@@ -150,33 +138,24 @@ TEST_P(umfCUDAProviderTest, allocInvalidSize) {
150138
umf_result = umfMemoryProviderAlloc(provider, 0, 0, &ptr);
151139
ASSERT_EQ(umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
152140
}
153-
154-
// destroy context and try to alloc some memory
155-
destroy_context((CUcontext)params.cuda_context_handle);
156-
params.cuda_context_handle = 0;
157-
umf_result = umfMemoryProviderAlloc(provider, 128, 0, &ptr);
158-
ASSERT_EQ(umf_result, UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC);
159-
160-
const char *message;
161-
int32_t error;
162-
umfMemoryProviderGetLastNativeError(provider, &message, &error);
163-
ASSERT_EQ(error, CUDA_ERROR_INVALID_CONTEXT);
164-
const char *expected_message =
165-
"CUDA_ERROR_INVALID_CONTEXT - invalid device context";
166-
ASSERT_EQ(strncmp(message, expected_message, strlen(expected_message)), 0);
167141
}
168142

169143
// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
170144

171-
CUDAMemoryAccessor cuAccessor;
145+
cuda_memory_provider_params_t cuParams_device_memory = create_cuda_prov_params(UMF_MEMORY_TYPE_DEVICE);
146+
cuda_memory_provider_params_t cuParams_shared_memory = create_cuda_prov_params(UMF_MEMORY_TYPE_SHARED);
147+
cuda_memory_provider_params_t cuParams_host_memory = create_cuda_prov_params(UMF_MEMORY_TYPE_HOST);
148+
149+
CUDAMemoryAccessor cuAccessor((CUcontext)cuParams_device_memory.cuda_context_handle,
150+
(CUdevice)cuParams_device_memory.cuda_device_handle);
172151
HostMemoryAccessor hostAccessor;
173152

174153
INSTANTIATE_TEST_SUITE_P(
175154
umfCUDAProviderTestSuite, umfCUDAProviderTest,
176155
::testing::Values(
177-
CUDAProviderTestParams{UMF_MEMORY_TYPE_DEVICE, &cuAccessor},
178-
CUDAProviderTestParams{UMF_MEMORY_TYPE_SHARED, &hostAccessor},
179-
CUDAProviderTestParams{UMF_MEMORY_TYPE_HOST, &hostAccessor}));
156+
CUDAProviderTestParams{cuParams_device_memory, &cuAccessor},
157+
CUDAProviderTestParams{cuParams_shared_memory, &hostAccessor},
158+
CUDAProviderTestParams{cuParams_host_memory, &hostAccessor}));
180159

181160
// TODO: add IPC API
182161
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(umfIpcTest);
@@ -185,5 +164,5 @@ INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfIpcTest,
185164
::testing::Values(ipcTestParams{
186165
umfProxyPoolOps(), nullptr,
187166
umfCUDAMemoryProviderOps(),
188-
&cuParams_device_memory, &l0Accessor}));
167+
&cuParams_device_memory, &cuAccessor}));
189168
*/

0 commit comments

Comments
 (0)