@@ -21,10 +21,8 @@ using namespace umf_test;
2121
2222class CUDAMemoryAccessor : public MemoryAccessor {
2323 public:
24- void init (CUcontext hContext, CUdevice hDevice) {
25- hDevice_ = hDevice;
26- hContext_ = hContext;
27- }
24+ CUDAMemoryAccessor (CUcontext hContext, CUdevice hDevice)
25+ : hDevice_(hDevice), hContext_(hContext) {}
2826
2927 void fill (void *ptr, size_t size, const void *pattern,
3028 size_t pattern_size) {
@@ -53,7 +51,7 @@ class CUDAMemoryAccessor : public MemoryAccessor {
5351};
5452
5553using CUDAProviderTestParams =
56- std::tuple<umf_usm_memory_type_t , MemoryAccessor *>;
54+ std::tuple<cuda_memory_provider_params_t , MemoryAccessor *>;
5755
5856struct umfCUDAProviderTest
5957 : umf_test::test,
@@ -62,14 +60,9 @@ struct umfCUDAProviderTest
6260 void SetUp () override {
6361 test::SetUp ();
6462
65- auto [memory_type , accessor] = this ->GetParam ();
66- params = create_cuda_prov_params (memory_type) ;
63+ auto [cuda_params , accessor] = this ->GetParam ();
64+ params = cuda_params ;
6765 memAccessor = accessor;
68- if (memory_type == UMF_MEMORY_TYPE_DEVICE) {
69- ((CUDAMemoryAccessor *)memAccessor)
70- ->init ((CUcontext)params.cuda_context_handle ,
71- params.cuda_device_handle );
72- }
7366 }
7467
7568 void TearDown () override {
@@ -168,15 +161,24 @@ TEST_P(umfCUDAProviderTest, allocInvalidSize) {
168161
169162// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
170163
171- CUDAMemoryAccessor cuAccessor;
164+ cuda_memory_provider_params_t cuParams_device_memory =
165+ create_cuda_prov_params (UMF_MEMORY_TYPE_DEVICE);
166+ cuda_memory_provider_params_t cuParams_shared_memory =
167+ create_cuda_prov_params (UMF_MEMORY_TYPE_SHARED);
168+ cuda_memory_provider_params_t cuParams_host_memory =
169+ create_cuda_prov_params (UMF_MEMORY_TYPE_HOST);
170+
171+ CUDAMemoryAccessor
172+ cuAccessor ((CUcontext)cuParams_device_memory.cuda_context_handle,
173+ (CUdevice)cuParams_device_memory.cuda_device_handle);
172174HostMemoryAccessor hostAccessor;
173175
174176INSTANTIATE_TEST_SUITE_P (
175177 umfCUDAProviderTestSuite, umfCUDAProviderTest,
176178 ::testing::Values (
177- CUDAProviderTestParams{UMF_MEMORY_TYPE_DEVICE , &cuAccessor},
178- CUDAProviderTestParams{UMF_MEMORY_TYPE_SHARED , &hostAccessor},
179- CUDAProviderTestParams{UMF_MEMORY_TYPE_HOST , &hostAccessor}));
179+ CUDAProviderTestParams{cuParams_device_memory , &cuAccessor},
180+ CUDAProviderTestParams{cuParams_shared_memory , &hostAccessor},
181+ CUDAProviderTestParams{cuParams_host_memory , &hostAccessor}));
180182
181183// TODO: add IPC API
182184GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST (umfIpcTest);
@@ -185,5 +187,5 @@ INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfIpcTest,
185187 ::testing::Values(ipcTestParams{
186188 umfProxyPoolOps(), nullptr,
187189 umfCUDAMemoryProviderOps(),
188- &cuParams_device_memory, &l0Accessor }));
190+ &cuParams_device_memory, &cuAccessor }));
189191*/
0 commit comments