Skip to content

Commit 6873226

Browse files
committed
Refactor CUDA tests
1 parent d007147 commit 6873226

File tree

1 file changed

+45
-59
lines changed

1 file changed

+45
-59
lines changed

test/providers/provider_cuda.cpp

Lines changed: 45 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ create_cuda_prov_params(CUcontext context, CUdevice device,
8383
return params;
8484
}
8585

86+
umf_result_t destroyCuParams(void *params) {
87+
return umfCUDAMemoryProviderParamsDestroy(
88+
(umf_cuda_memory_provider_params_handle_t)params);
89+
}
90+
8691
class CUDAMemoryAccessor : public MemoryAccessor {
8792
public:
8893
CUDAMemoryAccessor(CUcontext hContext, CUdevice hDevice)
@@ -114,47 +119,62 @@ class CUDAMemoryAccessor : public MemoryAccessor {
114119
CUcontext hContext_;
115120
};
116121

117-
typedef void *(*pfnProviderParamsCreate)();
118-
typedef umf_result_t (*pfnProviderParamsDestroy)(void *);
119-
120-
using CUDAProviderTestParams =
121-
std::tuple<pfnProviderParamsCreate, pfnProviderParamsDestroy, CUcontext,
122-
umf_usm_memory_type_t, MemoryAccessor *>;
123-
124122
struct umfCUDAProviderTest
125123
: umf_test::test,
126-
::testing::WithParamInterface<CUDAProviderTestParams> {
124+
::testing::WithParamInterface<umf_usm_memory_type_t> {
127125

128126
void SetUp() override {
129127
test::SetUp();
130128

131-
auto [params_create, params_destroy, cu_context, memory_type,
132-
accessor] = this->GetParam();
129+
umf_usm_memory_type_t memory_type = this->GetParam();
133130

134131
params = nullptr;
135-
if (params_create) {
136-
params = (umf_cuda_memory_provider_params_handle_t)params_create();
132+
memAccessor = nullptr;
133+
expected_context = cudaTestHelper.get_test_context();
134+
135+
ASSERT_NE(expected_context, nullptr);
136+
137+
switch (memory_type) {
138+
case UMF_MEMORY_TYPE_DEVICE:
139+
params = create_cuda_prov_params(cudaTestHelper.get_test_context(),
140+
cudaTestHelper.get_test_device(),
141+
memory_type);
142+
memAccessor = std::make_unique<CUDAMemoryAccessor>(
143+
cudaTestHelper.get_test_context(),
144+
cudaTestHelper.get_test_device());
145+
break;
146+
case UMF_MEMORY_TYPE_SHARED:
147+
params = create_cuda_prov_params(cudaTestHelper.get_test_context(),
148+
cudaTestHelper.get_test_device(),
149+
memory_type);
150+
memAccessor = std::make_unique<HostMemoryAccessor>();
151+
break;
152+
case UMF_MEMORY_TYPE_HOST:
153+
params = create_cuda_prov_params(cudaTestHelper.get_test_context(),
154+
cudaTestHelper.get_test_device(),
155+
memory_type);
156+
memAccessor = std::make_unique<HostMemoryAccessor>();
157+
break;
158+
case UMF_MEMORY_TYPE_UNKNOWN:
159+
break;
137160
}
138-
paramsDestroy = params_destroy;
139161

140-
memAccessor = accessor;
141-
expected_context = cu_context;
142162
expected_memory_type = memory_type;
143163
}
144164

145165
void TearDown() override {
146-
if (paramsDestroy) {
147-
paramsDestroy(params);
166+
if (params) {
167+
destroyCuParams(params);
148168
}
149169

150170
test::TearDown();
151171
}
152172

153-
umf_cuda_memory_provider_params_handle_t params;
154-
pfnProviderParamsDestroy paramsDestroy = nullptr;
173+
CUDATestHelper cudaTestHelper;
174+
umf_cuda_memory_provider_params_handle_t params = nullptr;
155175

156-
MemoryAccessor *memAccessor = nullptr;
157-
CUcontext expected_context;
176+
std::unique_ptr<MemoryAccessor> memAccessor = nullptr;
177+
CUcontext expected_context = nullptr;
158178
umf_usm_memory_type_t expected_memory_type;
159179
};
160180

@@ -391,44 +411,10 @@ TEST_P(umfCUDAProviderTest, multiContext) {
391411

392412
// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
393413

394-
CUDATestHelper cudaTestHelper;
395-
396-
void *createCuParamsDeviceMemory() {
397-
return create_cuda_prov_params(cudaTestHelper.get_test_context(),
398-
cudaTestHelper.get_test_device(),
399-
UMF_MEMORY_TYPE_DEVICE);
400-
}
401-
void *createCuParamsSharedMemory() {
402-
return create_cuda_prov_params(cudaTestHelper.get_test_context(),
403-
cudaTestHelper.get_test_device(),
404-
UMF_MEMORY_TYPE_SHARED);
405-
}
406-
void *createCuParamsHostMemory() {
407-
return create_cuda_prov_params(cudaTestHelper.get_test_context(),
408-
cudaTestHelper.get_test_device(),
409-
UMF_MEMORY_TYPE_HOST);
410-
}
411-
412-
umf_result_t destroyCuParams(void *params) {
413-
return umfCUDAMemoryProviderParamsDestroy(
414-
(umf_cuda_memory_provider_params_handle_t)params);
415-
}
416-
417-
CUDAMemoryAccessor cuAccessor(cudaTestHelper.get_test_context(),
418-
cudaTestHelper.get_test_device());
419-
HostMemoryAccessor hostAccessor;
420-
INSTANTIATE_TEST_SUITE_P(
421-
umfCUDAProviderTestSuite, umfCUDAProviderTest,
422-
::testing::Values(
423-
CUDAProviderTestParams{createCuParamsDeviceMemory, destroyCuParams,
424-
cudaTestHelper.get_test_context(),
425-
UMF_MEMORY_TYPE_DEVICE, &cuAccessor},
426-
CUDAProviderTestParams{createCuParamsSharedMemory, destroyCuParams,
427-
cudaTestHelper.get_test_context(),
428-
UMF_MEMORY_TYPE_SHARED, &hostAccessor},
429-
CUDAProviderTestParams{createCuParamsHostMemory, destroyCuParams,
430-
cudaTestHelper.get_test_context(),
431-
UMF_MEMORY_TYPE_HOST, &hostAccessor}));
414+
INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfCUDAProviderTest,
415+
::testing::Values(UMF_MEMORY_TYPE_DEVICE,
416+
UMF_MEMORY_TYPE_SHARED,
417+
UMF_MEMORY_TYPE_HOST));
432418

433419
// TODO: add IPC API
434420
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(umfIpcTest);

0 commit comments

Comments
 (0)