@@ -82,27 +82,50 @@ descToDisjoinPoolMemType(const usm::pool_descriptor &desc) {
8282}
8383
8484static umf::pool_unique_handle_t
85- makePool (umf_disjoint_pool_params_t *poolParams,
85+ makePool (umf_disjoint_pool_params_handle_t *poolParams,
8686 usm::pool_descriptor poolDescriptor) {
87- level_zero_memory_provider_params_t params = {};
88- params.level_zero_context_handle = poolDescriptor.hContext ->getZeHandle ();
89- params.level_zero_device_handle =
87+ umf_level_zero_memory_provider_params_handle_t params = NULL ;
88+ umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (¶ms);
89+ if (umf_ret != UMF_RESULT_SUCCESS) {
90+ throw umf::umf2urResult (umf_ret);
91+ }
92+
93+ umf_ret = umfLevelZeroMemoryProviderParamsSetContext (
94+ params, poolDescriptor.hContext ->getZeHandle ());
95+ if (umf_ret != UMF_RESULT_SUCCESS) {
96+ throw umf::umf2urResult (umf_ret);
97+ }
98+
99+ ze_device_handle_t level_zero_device_handle =
90100 poolDescriptor.hDevice ? poolDescriptor.hDevice ->ZeDevice : nullptr ;
91- params.memory_type = urToUmfMemoryType (poolDescriptor.type );
101+ umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (params,
102+ level_zero_device_handle);
103+ if (umf_ret != UMF_RESULT_SUCCESS) {
104+ throw umf::umf2urResult (umf_ret);
105+ }
106+
107+ umf_ret = umfLevelZeroMemoryProviderParamsSetMemoryType (
108+ params, urToUmfMemoryType (poolDescriptor.type ));
109+ if (umf_ret != UMF_RESULT_SUCCESS) {
110+ throw umf::umf2urResult (umf_ret);
111+ }
92112
93113 std::vector<ze_device_handle_t > residentZeHandles;
94114
95115 if (poolDescriptor.type == UR_USM_TYPE_DEVICE) {
96- assert (params. level_zero_device_handle );
116+ assert (level_zero_device_handle);
97117 auto residentHandles =
98118 poolDescriptor.hContext ->getP2PDevices (poolDescriptor.hDevice );
99- residentZeHandles.push_back (params. level_zero_device_handle );
119+ residentZeHandles.push_back (level_zero_device_handle);
100120 for (auto &device : residentHandles) {
101121 residentZeHandles.push_back (device->ZeDevice );
102122 }
103123
104- params.resident_device_handles = residentZeHandles.data ();
105- params.resident_device_count = residentZeHandles.size ();
124+ umf_ret = umfLevelZeroMemoryProviderParamsSetResidentDevices (
125+ params, residentZeHandles.data (), residentZeHandles.size ());
126+ if (umf_ret != UMF_RESULT_SUCCESS) {
127+ throw umf::umf2urResult (umf_ret);
128+ }
106129 }
107130
108131 auto [ret, provider] =
@@ -134,8 +157,18 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
134157 auto disjointPoolConfigs = initializeDisjointPoolConfig ();
135158 if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t >(pPoolDesc)) {
136159 for (auto &config : disjointPoolConfigs.Configs ) {
137- config.MaxPoolableSize = limits->maxPoolableSize ;
138- config.SlabMinSize = limits->minDriverAllocSize ;
160+ if (umfDisjointPoolParamsSetMaxPoolableSize (
161+ config, limits->maxPoolableSize ) != UMF_RESULT_SUCCESS) {
162+ logger::error (" urUSMPoolHandle: setting maxPoolableSize in "
163+ " DisjointPool params failed" );
164+ throw UR_RESULT_ERROR_UNKNOWN;
165+ }
166+ if (umfDisjointPoolParamsSetSlabMinSize (
167+ config, limits->minDriverAllocSize ) != UMF_RESULT_SUCCESS) {
168+ logger::error (" urUSMPoolHandle: setting slabMinSize in DisjointPool "
169+ " params failed" );
170+ throw UR_RESULT_ERROR_UNKNOWN;
171+ }
139172 }
140173 }
141174
0 commit comments