Skip to content

Commit 45c6767

Browse files
committed
Updated CUDA provider tests
1 parent 3a434fb commit 45c6767

File tree

2 files changed

+28
-25
lines changed

2 files changed

+28
-25
lines changed

test/providers/cuda_helpers.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct libcu_ops {
2626
CUresult (*cuMemFreeHost)(void *p);
2727
CUresult (*cuMemsetD32)(CUdeviceptr dstDevice, unsigned int pattern,
2828
size_t size);
29-
CUresult (*cuMemcpyDtoH)(void *dstHost, CUdeviceptr srcDevice, size_t size);
29+
CUresult (*cuMemcpy)(CUdeviceptr dst, CUdeviceptr src, size_t size);
3030
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
3131
CUpointer_attribute *attributes,
3232
void **data, CUdeviceptr ptr);
@@ -116,10 +116,10 @@ int InitCUDAOps() {
116116
fprintf(stderr, "cuMemsetD32_v2 symbol not found in %s\n", lib_name);
117117
return -1;
118118
}
119-
*(void **)&libcu_ops.cuMemcpyDtoH =
120-
utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpyDtoH_v2", lib_name);
121-
if (libcu_ops.cuMemcpyDtoH == nullptr) {
122-
fprintf(stderr, "cuMemcpyDtoH_v2 symbol not found in %s\n", lib_name);
119+
*(void **)&libcu_ops.cuMemcpy =
120+
utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpy_v2", lib_name);
121+
if (libcu_ops.cuMemcpy == nullptr) {
122+
fprintf(stderr, "cuMemcpy_v2 symbol not found in %s\n", lib_name);
123123
return -1;
124124
}
125125
*(void **)&libcu_ops.cuPointerGetAttributes = utils_get_symbol_addr(
@@ -147,7 +147,7 @@ int InitCUDAOps() {
147147
libcu_ops.cuMemFree = cuMemFree;
148148
libcu_ops.cuMemFreeHost = cuMemFreeHost;
149149
libcu_ops.cuMemsetD32 = cuMemsetD32;
150-
libcu_ops.cuMemcpyDtoH = cuMemcpyDtoH;
150+
libcu_ops.cuMemcpy = cuMemcpy;
151151
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
152152

153153
return 0;
@@ -193,9 +193,10 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr,
193193
(void)device;
194194

195195
int ret = 0;
196-
CUresult res = libcu_ops.cuMemcpyDtoH(dst_ptr, (CUdeviceptr)src_ptr, size);
196+
CUresult res =
197+
libcu_ops.cuMemcpy((CUdeviceptr)dst_ptr, (CUdeviceptr)src_ptr, size);
197198
if (res != CUDA_SUCCESS) {
198-
fprintf(stderr, "cuMemcpyDtoH() failed!\n");
199+
fprintf(stderr, "cuMemcpy() failed!\n");
199200
return -1;
200201
}
201202

test/providers/provider_cuda.cpp

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@ using namespace umf_test;
2121

2222
class CUDAMemoryAccessor : public MemoryAccessor {
2323
public:
24-
void init(CUcontext hContext, CUdevice hDevice) {
25-
hDevice_ = hDevice;
26-
hContext_ = hContext;
27-
}
24+
CUDAMemoryAccessor(CUcontext hContext, CUdevice hDevice)
25+
: hDevice_(hDevice), hContext_(hContext) {}
2826

2927
void fill(void *ptr, size_t size, const void *pattern,
3028
size_t pattern_size) {
@@ -53,7 +51,7 @@ class CUDAMemoryAccessor : public MemoryAccessor {
5351
};
5452

5553
using CUDAProviderTestParams =
56-
std::tuple<umf_usm_memory_type_t, MemoryAccessor *>;
54+
std::tuple<cuda_memory_provider_params_t, MemoryAccessor *>;
5755

5856
struct umfCUDAProviderTest
5957
: umf_test::test,
@@ -62,14 +60,9 @@ struct umfCUDAProviderTest
6260
void SetUp() override {
6361
test::SetUp();
6462

65-
auto [memory_type, accessor] = this->GetParam();
66-
params = create_cuda_prov_params(memory_type);
63+
auto [cuda_params, accessor] = this->GetParam();
64+
params = cuda_params;
6765
memAccessor = accessor;
68-
if (memory_type == UMF_MEMORY_TYPE_DEVICE) {
69-
((CUDAMemoryAccessor *)memAccessor)
70-
->init((CUcontext)params.cuda_context_handle,
71-
params.cuda_device_handle);
72-
}
7366
}
7467

7568
void TearDown() override {
@@ -168,15 +161,24 @@ TEST_P(umfCUDAProviderTest, allocInvalidSize) {
168161

169162
// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
170163

171-
CUDAMemoryAccessor cuAccessor;
164+
cuda_memory_provider_params_t cuParams_device_memory =
165+
create_cuda_prov_params(UMF_MEMORY_TYPE_DEVICE);
166+
cuda_memory_provider_params_t cuParams_shared_memory =
167+
create_cuda_prov_params(UMF_MEMORY_TYPE_SHARED);
168+
cuda_memory_provider_params_t cuParams_host_memory =
169+
create_cuda_prov_params(UMF_MEMORY_TYPE_HOST);
170+
171+
CUDAMemoryAccessor
172+
cuAccessor((CUcontext)cuParams_device_memory.cuda_context_handle,
173+
(CUdevice)cuParams_device_memory.cuda_device_handle);
172174
HostMemoryAccessor hostAccessor;
173175

174176
INSTANTIATE_TEST_SUITE_P(
175177
umfCUDAProviderTestSuite, umfCUDAProviderTest,
176178
::testing::Values(
177-
CUDAProviderTestParams{UMF_MEMORY_TYPE_DEVICE, &cuAccessor},
178-
CUDAProviderTestParams{UMF_MEMORY_TYPE_SHARED, &hostAccessor},
179-
CUDAProviderTestParams{UMF_MEMORY_TYPE_HOST, &hostAccessor}));
179+
CUDAProviderTestParams{cuParams_device_memory, &cuAccessor},
180+
CUDAProviderTestParams{cuParams_shared_memory, &hostAccessor},
181+
CUDAProviderTestParams{cuParams_host_memory, &hostAccessor}));
180182

181183
// TODO: add IPC API
182184
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(umfIpcTest);
@@ -185,5 +187,5 @@ INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfIpcTest,
185187
::testing::Values(ipcTestParams{
186188
umfProxyPoolOps(), nullptr,
187189
umfCUDAMemoryProviderOps(),
188-
&cuParams_device_memory, &l0Accessor}));
190+
&cuParams_device_memory, &cuAccessor}));
189191
*/

0 commit comments

Comments
 (0)