@@ -83,6 +83,11 @@ create_cuda_prov_params(CUcontext context, CUdevice device,
8383 return params;
8484}
8585
86+ umf_result_t destroyCuParams (void *params) {
87+ return umfCUDAMemoryProviderParamsDestroy (
88+ (umf_cuda_memory_provider_params_handle_t )params);
89+ }
90+
8691class CUDAMemoryAccessor : public MemoryAccessor {
8792 public:
8893 CUDAMemoryAccessor (CUcontext hContext, CUdevice hDevice)
@@ -114,47 +119,62 @@ class CUDAMemoryAccessor : public MemoryAccessor {
114119 CUcontext hContext_;
115120};
116121
117- typedef void *(*pfnProviderParamsCreate)();
118- typedef umf_result_t (*pfnProviderParamsDestroy)(void *);
119-
120- using CUDAProviderTestParams =
121- std::tuple<pfnProviderParamsCreate, pfnProviderParamsDestroy, CUcontext,
122- umf_usm_memory_type_t , MemoryAccessor *>;
123-
124122struct umfCUDAProviderTest
125123 : umf_test::test,
126- ::testing::WithParamInterface<CUDAProviderTestParams > {
124+ ::testing::WithParamInterface<umf_usm_memory_type_t > {
127125
128126 void SetUp () override {
129127 test::SetUp ();
130128
131- auto [params_create, params_destroy, cu_context, memory_type,
132- accessor] = this ->GetParam ();
129+ umf_usm_memory_type_t memory_type = this ->GetParam ();
133130
134131 params = nullptr ;
135- if (params_create) {
136- params = (umf_cuda_memory_provider_params_handle_t )params_create ();
132+ memAccessor = nullptr ;
133+ expected_context = cudaTestHelper.get_test_context ();
134+
135+ ASSERT_NE (expected_context, nullptr );
136+
137+ switch (memory_type) {
138+ case UMF_MEMORY_TYPE_DEVICE:
139+ params = create_cuda_prov_params (cudaTestHelper.get_test_context (),
140+ cudaTestHelper.get_test_device (),
141+ memory_type);
142+ memAccessor = std::make_unique<CUDAMemoryAccessor>(
143+ cudaTestHelper.get_test_context (),
144+ cudaTestHelper.get_test_device ());
145+ break ;
146+ case UMF_MEMORY_TYPE_SHARED:
147+ params = create_cuda_prov_params (cudaTestHelper.get_test_context (),
148+ cudaTestHelper.get_test_device (),
149+ memory_type);
150+ memAccessor = std::make_unique<HostMemoryAccessor>();
151+ break ;
152+ case UMF_MEMORY_TYPE_HOST:
153+ params = create_cuda_prov_params (cudaTestHelper.get_test_context (),
154+ cudaTestHelper.get_test_device (),
155+ memory_type);
156+ memAccessor = std::make_unique<HostMemoryAccessor>();
157+ break ;
158+ case UMF_MEMORY_TYPE_UNKNOWN:
159+ break ;
137160 }
138- paramsDestroy = params_destroy;
139161
140- memAccessor = accessor;
141- expected_context = cu_context;
142162 expected_memory_type = memory_type;
143163 }
144164
145165 void TearDown () override {
146- if (paramsDestroy ) {
147- paramsDestroy (params);
166+ if (params ) {
167+ destroyCuParams (params);
148168 }
149169
150170 test::TearDown ();
151171 }
152172
153- umf_cuda_memory_provider_params_handle_t params ;
154- pfnProviderParamsDestroy paramsDestroy = nullptr ;
173+ CUDATestHelper cudaTestHelper ;
174+ umf_cuda_memory_provider_params_handle_t params = nullptr ;
155175
156- MemoryAccessor * memAccessor = nullptr ;
157- CUcontext expected_context;
176+ std::unique_ptr< MemoryAccessor> memAccessor = nullptr ;
177+ CUcontext expected_context = nullptr ;
158178 umf_usm_memory_type_t expected_memory_type;
159179};
160180
@@ -391,44 +411,10 @@ TEST_P(umfCUDAProviderTest, multiContext) {
391411
392412// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
393413
394- CUDATestHelper cudaTestHelper;
395-
396- void *createCuParamsDeviceMemory () {
397- return create_cuda_prov_params (cudaTestHelper.get_test_context (),
398- cudaTestHelper.get_test_device (),
399- UMF_MEMORY_TYPE_DEVICE);
400- }
401- void *createCuParamsSharedMemory () {
402- return create_cuda_prov_params (cudaTestHelper.get_test_context (),
403- cudaTestHelper.get_test_device (),
404- UMF_MEMORY_TYPE_SHARED);
405- }
406- void *createCuParamsHostMemory () {
407- return create_cuda_prov_params (cudaTestHelper.get_test_context (),
408- cudaTestHelper.get_test_device (),
409- UMF_MEMORY_TYPE_HOST);
410- }
411-
412- umf_result_t destroyCuParams (void *params) {
413- return umfCUDAMemoryProviderParamsDestroy (
414- (umf_cuda_memory_provider_params_handle_t )params);
415- }
416-
417- CUDAMemoryAccessor cuAccessor (cudaTestHelper.get_test_context(),
418- cudaTestHelper.get_test_device());
419- HostMemoryAccessor hostAccessor;
420- INSTANTIATE_TEST_SUITE_P (
421- umfCUDAProviderTestSuite, umfCUDAProviderTest,
422- ::testing::Values (
423- CUDAProviderTestParams{createCuParamsDeviceMemory, destroyCuParams,
424- cudaTestHelper.get_test_context (),
425- UMF_MEMORY_TYPE_DEVICE, &cuAccessor},
426- CUDAProviderTestParams{createCuParamsSharedMemory, destroyCuParams,
427- cudaTestHelper.get_test_context (),
428- UMF_MEMORY_TYPE_SHARED, &hostAccessor},
429- CUDAProviderTestParams{createCuParamsHostMemory, destroyCuParams,
430- cudaTestHelper.get_test_context (),
431- UMF_MEMORY_TYPE_HOST, &hostAccessor}));
414+ INSTANTIATE_TEST_SUITE_P (umfCUDAProviderTestSuite, umfCUDAProviderTest,
415+ ::testing::Values (UMF_MEMORY_TYPE_DEVICE,
416+ UMF_MEMORY_TYPE_SHARED,
417+ UMF_MEMORY_TYPE_HOST));
432418
433419// TODO: add IPC API
434420GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST (umfIpcTest);
0 commit comments