@@ -21,10 +21,7 @@ using namespace umf_test;
2121
2222class CUDAMemoryAccessor : public MemoryAccessor {
2323 public:
24- void init (CUcontext hContext, CUdevice hDevice) {
25- hDevice_ = hDevice;
26- hContext_ = hContext;
27- }
24+ CUDAMemoryAccessor (CUcontext hContext, CUdevice hDevice): hDevice_(hDevice), hContext_(hContext) {}
2825
2926 void fill (void *ptr, size_t size, const void *pattern,
3027 size_t pattern_size) {
@@ -53,7 +50,7 @@ class CUDAMemoryAccessor : public MemoryAccessor {
5350};
5451
5552using CUDAProviderTestParams =
56- std::tuple<umf_usm_memory_type_t , MemoryAccessor *>;
53+ std::tuple<cuda_memory_provider_params_t , MemoryAccessor *>;
5754
5855struct umfCUDAProviderTest
5956 : umf_test::test,
@@ -62,21 +59,12 @@ struct umfCUDAProviderTest
6259 void SetUp () override {
6360 test::SetUp ();
6461
65- auto [memory_type , accessor] = this ->GetParam ();
66- params = create_cuda_prov_params (memory_type) ;
62+ auto [cuda_params , accessor] = this ->GetParam ();
63+ params = cuda_params ;
6764 memAccessor = accessor;
68- if (memory_type == UMF_MEMORY_TYPE_DEVICE) {
69- ((CUDAMemoryAccessor *)memAccessor)
70- ->init ((CUcontext)params.cuda_context_handle ,
71- params.cuda_device_handle );
72- }
7365 }
7466
7567 void TearDown () override {
76- if (params.cuda_context_handle ) {
77- int ret = destroy_context ((CUcontext)params.cuda_context_handle );
78- ASSERT_EQ (ret, 0 );
79- }
8068 test::TearDown ();
8169 }
8270
@@ -150,33 +138,24 @@ TEST_P(umfCUDAProviderTest, allocInvalidSize) {
150138 umf_result = umfMemoryProviderAlloc (provider, 0 , 0 , &ptr);
151139 ASSERT_EQ (umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
152140 }
153-
154- // destroy context and try to alloc some memory
155- destroy_context ((CUcontext)params.cuda_context_handle );
156- params.cuda_context_handle = 0 ;
157- umf_result = umfMemoryProviderAlloc (provider, 128 , 0 , &ptr);
158- ASSERT_EQ (umf_result, UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC);
159-
160- const char *message;
161- int32_t error;
162- umfMemoryProviderGetLastNativeError (provider, &message, &error);
163- ASSERT_EQ (error, CUDA_ERROR_INVALID_CONTEXT);
164- const char *expected_message =
165- " CUDA_ERROR_INVALID_CONTEXT - invalid device context" ;
166- ASSERT_EQ (strncmp (message, expected_message, strlen (expected_message)), 0 );
167141}
168142
169143// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
170144
171- CUDAMemoryAccessor cuAccessor;
145+ cuda_memory_provider_params_t cuParams_device_memory = create_cuda_prov_params(UMF_MEMORY_TYPE_DEVICE);
146+ cuda_memory_provider_params_t cuParams_shared_memory = create_cuda_prov_params(UMF_MEMORY_TYPE_SHARED);
147+ cuda_memory_provider_params_t cuParams_host_memory = create_cuda_prov_params(UMF_MEMORY_TYPE_HOST);
148+
149+ CUDAMemoryAccessor cuAccessor ((CUcontext)cuParams_device_memory.cuda_context_handle,
150+ (CUdevice)cuParams_device_memory.cuda_device_handle);
172151HostMemoryAccessor hostAccessor;
173152
174153INSTANTIATE_TEST_SUITE_P (
175154 umfCUDAProviderTestSuite, umfCUDAProviderTest,
176155 ::testing::Values (
177- CUDAProviderTestParams{UMF_MEMORY_TYPE_DEVICE , &cuAccessor},
178- CUDAProviderTestParams{UMF_MEMORY_TYPE_SHARED , &hostAccessor},
179- CUDAProviderTestParams{UMF_MEMORY_TYPE_HOST , &hostAccessor}));
156+ CUDAProviderTestParams{cuParams_device_memory , &cuAccessor},
157+ CUDAProviderTestParams{cuParams_shared_memory , &hostAccessor},
158+ CUDAProviderTestParams{cuParams_host_memory , &hostAccessor}));
180159
181160// TODO: add IPC API
182161GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST (umfIpcTest);
@@ -185,5 +164,5 @@ INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfIpcTest,
185164 ::testing::Values(ipcTestParams{
186165 umfProxyPoolOps(), nullptr,
187166 umfCUDAMemoryProviderOps(),
188- &cuParams_device_memory, &l0Accessor }));
167+ &cuParams_device_memory, &cuAccessor }));
189168*/
0 commit comments