@@ -383,6 +383,47 @@ TEST_P(umfCUDAProviderTest, cudaProviderNullParams) {
383383 EXPECT_EQ (res, UMF_RESULT_ERROR_INVALID_ARGUMENT);
384384}
385385
386+ TEST_P (umfCUDAProviderTest, cudaProviderInvalidCreate) {
387+ CUdevice device;
388+ int ret = get_cuda_device (&device);
389+ ASSERT_EQ (ret, 0 );
390+
391+ CUcontext ctx;
392+ ret = create_context (device, &ctx);
393+ ASSERT_EQ (ret, 0 );
394+
395+ // wrong memory type
396+ umf_cuda_memory_provider_params_handle_t params_wrong_memtype =
397+ create_cuda_prov_params (ctx, device, (umf_usm_memory_type_t )0xFFFF , 0 );
398+ ASSERT_NE (params_wrong_memtype, nullptr );
399+ umf_memory_provider_handle_t provider;
400+ umf_result_t umf_result = umfMemoryProviderCreate (
401+ umfCUDAMemoryProviderOps (), params_wrong_memtype, &provider);
402+ ASSERT_EQ (umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
403+ umf_result = umfCUDAMemoryProviderParamsDestroy (params_wrong_memtype);
404+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
405+
406+ // wrong context
407+ umf_cuda_memory_provider_params_handle_t params_wrong_ctx =
408+ create_cuda_prov_params (NULL , device, UMF_MEMORY_TYPE_HOST, 0 );
409+ ASSERT_NE (params_wrong_ctx, nullptr );
410+ umf_result = umfMemoryProviderCreate (umfCUDAMemoryProviderOps (),
411+ params_wrong_ctx, &provider);
412+ ASSERT_EQ (umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
413+ umf_result = umfCUDAMemoryProviderParamsDestroy (params_wrong_ctx);
414+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
415+
416+ // wrong device
417+ umf_cuda_memory_provider_params_handle_t params_wrong_device =
418+ create_cuda_prov_params (ctx, (CUdevice)-1 , UMF_MEMORY_TYPE_HOST, 0 );
419+ ASSERT_NE (params_wrong_device, nullptr );
420+ umf_result = umfMemoryProviderCreate (umfCUDAMemoryProviderOps (),
421+ params_wrong_device, &provider);
422+ ASSERT_EQ (umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
423+ umf_result = umfCUDAMemoryProviderParamsDestroy (params_wrong_device);
424+ ASSERT_EQ (umf_result, UMF_RESULT_SUCCESS);
425+ }
426+
386427TEST_P (umfCUDAProviderTest, multiContext) {
387428 CUdevice device;
388429 int ret = get_cuda_device (&device);
0 commit comments