@@ -315,6 +315,62 @@ TEST_P(umfCUDAProviderTest, cudaProviderNullParams) {
315315 EXPECT_EQ (res, UMF_RESULT_ERROR_INVALID_ARGUMENT);
316316}
317317
318+ TEST_P (umfCUDAProviderTest, multiContext) {
319+ CUdevice device;
320+ int ret = get_cuda_device (&device);
321+ ASSERT_EQ (ret, 0 );
322+
323+ // create two CUDA contexts and two providers
324+ CUcontext ctx1, ctx2;
325+ ret = create_context (device, &ctx1);
326+ ASSERT_EQ (ret, 0 );
327+ ret = create_context (device, &ctx2);
328+ ASSERT_EQ (ret, 0 );
329+
330+ cuda_params_unique_handle_t params1 =
331+ create_cuda_prov_params (ctx1, device, UMF_MEMORY_TYPE_HOST);
332+ ASSERT_NE (params1, nullptr );
333+ umf_memory_provider_handle_t provider1;
334+ umf_result_t umf_result = umfMemoryProviderCreate (
335+ umfCUDAMemoryProviderOps (), params1.get (), &provider1);
336+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
337+ ASSERT_NE (provider1, nullptr );
338+
339+ cuda_params_unique_handle_t params2 =
340+ create_cuda_prov_params (ctx2, device, UMF_MEMORY_TYPE_HOST);
341+ ASSERT_NE (params2, nullptr );
342+ umf_memory_provider_handle_t provider2;
343+ umf_result = umfMemoryProviderCreate (umfCUDAMemoryProviderOps (),
344+ params2.get (), &provider2);
345+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
346+ ASSERT_NE (provider2, nullptr );
347+
348+ // use the providers
349+ // allocate from 1, then from 2, then free 1, then free 2
350+ void *ptr1, *ptr2;
351+ const int size = 128 ;
352+ umf_result = umfMemoryProviderAlloc (provider1, size, 0 , &ptr1);
353+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
354+ ASSERT_NE (ptr1, nullptr );
355+
356+ umf_result = umfMemoryProviderAlloc (provider2, size, 0 , &ptr2);
357+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
358+ ASSERT_NE (ptr2, nullptr );
359+
360+ umf_result = umfMemoryProviderFree (provider1, ptr1, size);
361+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
362+
363+ umf_result = umfMemoryProviderFree (provider2, ptr1, size);
364+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
365+
366+ // cleanup
367+ umfMemoryProviderDestroy (provider1);
368+ ret = destroy_context (ctx1);
369+ ASSERT_EQ (ret, 0 );
370+ ret = destroy_context (ctx2);
371+ ASSERT_EQ (ret, 0 );
372+ }
373+
318374// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
319375
320376CUDATestHelper cudaTestHelper;
0 commit comments