@@ -82,31 +82,55 @@ descToDisjoinPoolMemType(const usm::pool_descriptor &desc) {
8282}
8383
8484static umf::pool_unique_handle_t
85- makePool (umf_disjoint_pool_params_t *poolParams,
85+ makePool (usm:: umf_disjoint_pool_config_t *poolParams,
8686 usm::pool_descriptor poolDescriptor) {
87- level_zero_memory_provider_params_t params = {};
88- params.level_zero_context_handle = poolDescriptor.hContext ->getZeHandle ();
89- params.level_zero_device_handle =
87+ umf_level_zero_memory_provider_params_handle_t params = NULL ;
88+ umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (¶ms);
89+ if (umf_ret != UMF_RESULT_SUCCESS) {
90+ throw umf::umf2urResult (umf_ret);
91+ }
92+
93+ umf_ret = umfLevelZeroMemoryProviderParamsSetContext (
94+ params, poolDescriptor.hContext ->getZeHandle ());
95+ if (umf_ret != UMF_RESULT_SUCCESS) {
96+ throw umf::umf2urResult (umf_ret);
97+ };
98+
99+ ze_device_handle_t level_zero_device_handle =
90100 poolDescriptor.hDevice ? poolDescriptor.hDevice ->ZeDevice : nullptr ;
91- params.memory_type = urToUmfMemoryType (poolDescriptor.type );
101+
102+ umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (params,
103+ level_zero_device_handle);
104+ if (umf_ret != UMF_RESULT_SUCCESS) {
105+ throw umf::umf2urResult (umf_ret);
106+ }
107+
108+ umf_ret = umfLevelZeroMemoryProviderParamsSetMemoryType (
109+ params, urToUmfMemoryType (poolDescriptor.type ));
110+ if (umf_ret != UMF_RESULT_SUCCESS) {
111+ throw umf::umf2urResult (umf_ret);
112+ }
92113
93114 std::vector<ze_device_handle_t > residentZeHandles;
94115
95116 if (poolDescriptor.type == UR_USM_TYPE_DEVICE) {
96- assert (params. level_zero_device_handle );
117+ assert (level_zero_device_handle);
97118 auto residentHandles =
98119 poolDescriptor.hContext ->getP2PDevices (poolDescriptor.hDevice );
99- residentZeHandles.push_back (params. level_zero_device_handle );
120+ residentZeHandles.push_back (level_zero_device_handle);
100121 for (auto &device : residentHandles) {
101122 residentZeHandles.push_back (device->ZeDevice );
102123 }
103124
104- params.resident_device_handles = residentZeHandles.data ();
105- params.resident_device_count = residentZeHandles.size ();
125+ umf_ret = umfLevelZeroMemoryProviderParamsSetResidentDevices (
126+ params, residentZeHandles.data (), residentZeHandles.size ());
127+ if (umf_ret != UMF_RESULT_SUCCESS) {
128+ throw umf::umf2urResult (umf_ret);
129+ }
106130 }
107131
108132 auto [ret, provider] =
109- umf::providerMakeUniqueFromOps (umfLevelZeroMemoryProviderOps (), & params);
133+ umf::providerMakeUniqueFromOps (umfLevelZeroMemoryProviderOps (), params);
110134 if (ret != UMF_RESULT_SUCCESS) {
111135 throw umf::umf2urResult (ret);
112136 }
@@ -118,9 +142,11 @@ makePool(umf_disjoint_pool_params_t *poolParams,
118142 throw umf::umf2urResult (ret);
119143 return std::move (poolHandle);
120144 } else {
145+ auto umfParams = getUmfParamsHandle (*poolParams);
146+
121147 auto [ret, poolHandle] =
122148 umf::poolMakeUniqueFromOps (umfDisjointPoolOps (), std::move (provider),
123- static_cast <void *>(poolParams ));
149+ static_cast <void *>(umfParams. get () ));
124150 if (ret != UMF_RESULT_SUCCESS)
125151 throw umf::umf2urResult (ret);
126152 return std::move (poolHandle);
@@ -199,10 +225,13 @@ ur_result_t urUSMPoolCreate(
199225 pPoolDesc, // /< [in] pointer to USM pool descriptor. Can be chained with
200226 // /< ::ur_usm_pool_limits_desc_t
201227 ur_usm_pool_handle_t *hPool // /< [out] pointer to USM memory pool
202- ) {
203-
228+ ) try {
204229 *hPool = new ur_usm_pool_handle_t_ (hContext, pPoolDesc);
205230 return UR_RESULT_SUCCESS;
231+ } catch (umf_result_t e) {
232+ return umf::umf2urResult (e);
233+ } catch (...) {
234+ return exceptionToResult (std::current_exception ());
206235}
207236
208237ur_result_t
0 commit comments