@@ -21,10 +21,8 @@ using namespace umf_test;
2121
2222class CUDAMemoryAccessor : public MemoryAccessor {
2323 public:
24- void init (CUcontext hContext, CUdevice hDevice) {
25- hDevice_ = hDevice;
26- hContext_ = hContext;
27- }
24+ CUDAMemoryAccessor (CUcontext hContext, CUdevice hDevice)
25+ : hDevice_(hDevice), hContext_(hContext) {}
2826
2927 void fill (void *ptr, size_t size, const void *pattern,
3028 size_t pattern_size) {
@@ -53,7 +51,7 @@ class CUDAMemoryAccessor : public MemoryAccessor {
5351};
5452
5553using CUDAProviderTestParams =
56- std::tuple<umf_usm_memory_type_t , MemoryAccessor *>;
54+ std::tuple<cuda_memory_provider_params_t , MemoryAccessor *>;
5755
5856struct umfCUDAProviderTest
5957 : umf_test::test,
@@ -62,23 +60,12 @@ struct umfCUDAProviderTest
6260 void SetUp () override {
6361 test::SetUp ();
6462
65- auto [memory_type , accessor] = this ->GetParam ();
66- params = create_cuda_prov_params (memory_type) ;
63+ auto [cuda_params , accessor] = this ->GetParam ();
64+ params = cuda_params ;
6765 memAccessor = accessor;
68- if (memory_type == UMF_MEMORY_TYPE_DEVICE) {
69- ((CUDAMemoryAccessor *)memAccessor)
70- ->init ((CUcontext)params.cuda_context_handle ,
71- params.cuda_device_handle );
72- }
7366 }
7467
75- void TearDown () override {
76- if (params.cuda_context_handle ) {
77- int ret = destroy_context ((CUcontext)params.cuda_context_handle );
78- ASSERT_EQ (ret, 0 );
79- }
80- test::TearDown ();
81- }
68+ void TearDown () override { test::TearDown (); }
8269
8370 cuda_memory_provider_params_t params;
8471 MemoryAccessor *memAccessor = nullptr ;
@@ -150,33 +137,28 @@ TEST_P(umfCUDAProviderTest, allocInvalidSize) {
150137 umf_result = umfMemoryProviderAlloc (provider, 0 , 0 , &ptr);
151138 ASSERT_EQ (umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
152139 }
153-
154- // destroy context and try to alloc some memory
155- destroy_context ((CUcontext)params.cuda_context_handle );
156- params.cuda_context_handle = 0 ;
157- umf_result = umfMemoryProviderAlloc (provider, 128 , 0 , &ptr);
158- ASSERT_EQ (umf_result, UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC);
159-
160- const char *message;
161- int32_t error;
162- umfMemoryProviderGetLastNativeError (provider, &message, &error);
163- ASSERT_EQ (error, CUDA_ERROR_INVALID_CONTEXT);
164- const char *expected_message =
165- " CUDA_ERROR_INVALID_CONTEXT - invalid device context" ;
166- ASSERT_EQ (strncmp (message, expected_message, strlen (expected_message)), 0 );
167140}
168141
169142// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
170143
171- CUDAMemoryAccessor cuAccessor;
144+ cuda_memory_provider_params_t cuParams_device_memory =
145+ create_cuda_prov_params (UMF_MEMORY_TYPE_DEVICE);
146+ cuda_memory_provider_params_t cuParams_shared_memory =
147+ create_cuda_prov_params (UMF_MEMORY_TYPE_SHARED);
148+ cuda_memory_provider_params_t cuParams_host_memory =
149+ create_cuda_prov_params (UMF_MEMORY_TYPE_HOST);
150+
151+ CUDAMemoryAccessor
152+ cuAccessor ((CUcontext)cuParams_device_memory.cuda_context_handle,
153+ (CUdevice)cuParams_device_memory.cuda_device_handle);
172154HostMemoryAccessor hostAccessor;
173155
174156INSTANTIATE_TEST_SUITE_P (
175157 umfCUDAProviderTestSuite, umfCUDAProviderTest,
176158 ::testing::Values (
177- CUDAProviderTestParams{UMF_MEMORY_TYPE_DEVICE , &cuAccessor},
178- CUDAProviderTestParams{UMF_MEMORY_TYPE_SHARED , &hostAccessor},
179- CUDAProviderTestParams{UMF_MEMORY_TYPE_HOST , &hostAccessor}));
159+ CUDAProviderTestParams{cuParams_device_memory , &cuAccessor},
160+ CUDAProviderTestParams{cuParams_shared_memory , &hostAccessor},
161+ CUDAProviderTestParams{cuParams_host_memory , &hostAccessor}));
180162
181163// TODO: add IPC API
182164GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST (umfIpcTest);
@@ -185,5 +167,5 @@ INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfIpcTest,
185167 ::testing::Values(ipcTestParams{
186168 umfProxyPoolOps(), nullptr,
187169 umfCUDAMemoryProviderOps(),
188- &cuParams_device_memory, &l0Accessor }));
170+ &cuParams_device_memory, &cuAccessor }));
189171*/
0 commit comments