@@ -82,27 +82,50 @@ descToDisjoinPoolMemType(const usm::pool_descriptor &desc) {
8282}
8383
8484static umf::pool_unique_handle_t
85- makePool (umf_disjoint_pool_params_t *poolParams,
85+ makePool (umf_disjoint_pool_params_handle_t *poolParams,
8686 usm::pool_descriptor poolDescriptor) {
87- level_zero_memory_provider_params_t params = {};
88- params.level_zero_context_handle = poolDescriptor.hContext ->getZeHandle ();
89- params.level_zero_device_handle =
87+ umf_level_zero_memory_provider_params_handle_t params = NULL ;
88+ umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (¶ms);
89+ if (umf_ret != UMF_RESULT_SUCCESS) {
90+ throw umf::umf2urResult (umf_ret);
91+ }
92+
93+ umf_ret = umfLevelZeroMemoryProviderParamsSetContext (
94+ params, poolDescriptor.hContext ->getZeHandle ());
95+ if (umf_ret != UMF_RESULT_SUCCESS) {
96+ throw umf::umf2urResult (umf_ret);
97+ }
98+
99+ ze_device_handle_t level_zero_device_handle =
90100 poolDescriptor.hDevice ? poolDescriptor.hDevice ->ZeDevice : nullptr ;
91- params.memory_type = urToUmfMemoryType (poolDescriptor.type );
101+ umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (params,
102+ level_zero_device_handle);
103+ if (umf_ret != UMF_RESULT_SUCCESS) {
104+ throw umf::umf2urResult (umf_ret);
105+ }
106+
107+ umf_ret = umfLevelZeroMemoryProviderParamsSetMemoryType (
108+ params, urToUmfMemoryType (poolDescriptor.type ));
109+ if (umf_ret != UMF_RESULT_SUCCESS) {
110+ throw umf::umf2urResult (umf_ret);
111+ }
92112
93113 std::vector<ze_device_handle_t > residentZeHandles;
94114
95115 if (poolDescriptor.type == UR_USM_TYPE_DEVICE) {
96- assert (params. level_zero_device_handle );
116+ assert (level_zero_device_handle);
97117 auto residentHandles =
98118 poolDescriptor.hContext ->getP2PDevices (poolDescriptor.hDevice );
99- residentZeHandles.push_back (params. level_zero_device_handle );
119+ residentZeHandles.push_back (level_zero_device_handle);
100120 for (auto &device : residentHandles) {
101121 residentZeHandles.push_back (device->ZeDevice );
102122 }
103123
104- params.resident_device_handles = residentZeHandles.data ();
105- params.resident_device_count = residentZeHandles.size ();
124+ umf_ret = umfLevelZeroMemoryProviderParamsSetResidentDevices (
125+ params, residentZeHandles.data (), residentZeHandles.size ());
126+ if (umf_ret != UMF_RESULT_SUCCESS) {
127+ throw umf::umf2urResult (umf_ret);
128+ }
106129 }
107130
108131 auto [ret, provider] =
@@ -134,8 +157,20 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
134157 auto disjointPoolConfigs = initializeDisjointPoolConfig ();
135158 if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t >(pPoolDesc)) {
136159 for (auto &config : disjointPoolConfigs.Configs ) {
137- config.MaxPoolableSize = limits->maxPoolableSize ;
138- config.SlabMinSize = limits->minDriverAllocSize ;
160+ umf_result_t umf_ret = umfDisjointPoolParamsSetMaxPoolableSize (
161+ config, limits->maxPoolableSize );
162+ if (umf_ret != UMF_RESULT_SUCCESS) {
163+ logger::error (" urUSMPoolHandle: setting maxPoolableSize in "
164+ " DisjointPool params failed" );
165+ throw umf::umf2urResult (umf_ret);
166+ }
167+ umf_ret = umfDisjointPoolParamsSetSlabMinSize (config,
168+ limits->minDriverAllocSize );
169+ if (umf_ret != UMF_RESULT_SUCCESS) {
170+ logger::error (" urUSMPoolHandle: setting slabMinSize in DisjointPool "
171+ " params failed" );
172+ throw umf::umf2urResult (umf_ret);
173+ }
139174 }
140175 }
141176
0 commit comments