@@ -34,7 +34,17 @@ ur_result_t getProviderNativeError(const char *providerName,
3434}
3535} // namespace umf
3636
37- static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig () {
37+ static std::optional<usm::DisjointPoolAllConfigs>
38+ initializeDisjointPoolConfig () {
39+ const char *UrRetDisable = std::getenv (" UR_L0_DISABLE_USM_ALLOCATOR" );
40+ const char *PiRetDisable =
41+ std::getenv (" SYCL_PI_LEVEL_ZERO_DISABLE_USM_ALLOCATOR" );
42+ const char *Disable =
43+ UrRetDisable ? UrRetDisable : (PiRetDisable ? PiRetDisable : nullptr );
44+ if (Disable != nullptr && Disable != std::string (" " )) {
45+ return std::nullopt ;
46+ }
47+
3848 const char *PoolUrTraceVal = std::getenv (" UR_L0_USM_ALLOCATOR_TRACE" );
3949
4050 int PoolTrace = 0 ;
@@ -47,7 +57,14 @@ static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig() {
4757 return usm::DisjointPoolAllConfigs (PoolTrace);
4858 }
4959
50- return usm::parseDisjointPoolConfig (PoolUrConfigVal, PoolTrace);
60+ // TODO: rework parseDisjointPoolConfig to return optional,
61+ // once EnableBuffers is no longer used (by legacy L0)
62+ auto configs = usm::parseDisjointPoolConfig (PoolUrConfigVal, PoolTrace);
63+ if (configs.EnableBuffers ) {
64+ return configs;
65+ }
66+
67+ return std::nullopt ;
5168}
5269
5370inline umf_usm_memory_type_t urToUmfMemoryType (ur_usm_type_t type) {
@@ -81,32 +98,35 @@ descToDisjoinPoolMemType(const usm::pool_descriptor &desc) {
8198 }
8299}
83100
84- static umf::pool_unique_handle_t
85- makePool (usm::umf_disjoint_pool_config_t *poolParams,
86- usm::pool_descriptor poolDescriptor) {
87- umf_level_zero_memory_provider_params_handle_t params = NULL ;
88- umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (¶ms);
101+ static umf::provider_unique_handle_t
102+ makeProvider (usm::pool_descriptor poolDescriptor) {
103+ umf_level_zero_memory_provider_params_handle_t hParams;
104+ umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (&hParams);
89105 if (umf_ret != UMF_RESULT_SUCCESS) {
90106 throw umf::umf2urResult (umf_ret);
91107 }
92108
109+ std::unique_ptr<umf_level_zero_memory_provider_params_t ,
110+ decltype (&umfLevelZeroMemoryProviderParamsDestroy)>
111+ params (hParams, &umfLevelZeroMemoryProviderParamsDestroy);
112+
93113 umf_ret = umfLevelZeroMemoryProviderParamsSetContext (
94- params , poolDescriptor.hContext ->getZeHandle ());
114+ hParams , poolDescriptor.hContext ->getZeHandle ());
95115 if (umf_ret != UMF_RESULT_SUCCESS) {
96116 throw umf::umf2urResult (umf_ret);
97117 };
98118
99119 ze_device_handle_t level_zero_device_handle =
100120 poolDescriptor.hDevice ? poolDescriptor.hDevice ->ZeDevice : nullptr ;
101121
102- umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (params ,
122+ umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (hParams ,
103123 level_zero_device_handle);
104124 if (umf_ret != UMF_RESULT_SUCCESS) {
105125 throw umf::umf2urResult (umf_ret);
106126 }
107127
108128 umf_ret = umfLevelZeroMemoryProviderParamsSetMemoryType (
109- params , urToUmfMemoryType (poolDescriptor.type ));
129+ hParams , urToUmfMemoryType (poolDescriptor.type ));
110130 if (umf_ret != UMF_RESULT_SUCCESS) {
111131 throw umf::umf2urResult (umf_ret);
112132 }
@@ -123,46 +143,59 @@ makePool(usm::umf_disjoint_pool_config_t *poolParams,
123143 }
124144
125145 umf_ret = umfLevelZeroMemoryProviderParamsSetResidentDevices (
126- params , residentZeHandles.data (), residentZeHandles.size ());
146+ hParams , residentZeHandles.data (), residentZeHandles.size ());
127147 if (umf_ret != UMF_RESULT_SUCCESS) {
128148 throw umf::umf2urResult (umf_ret);
129149 }
130150 }
131151
132152 auto [ret, provider] =
133- umf::providerMakeUniqueFromOps (umfLevelZeroMemoryProviderOps (), params );
153+ umf::providerMakeUniqueFromOps (umfLevelZeroMemoryProviderOps (), hParams );
134154 if (ret != UMF_RESULT_SUCCESS) {
135155 throw umf::umf2urResult (ret);
136156 }
137157
138- if (!poolParams) {
139- auto [ret, poolHandle] = umf::poolMakeUniqueFromOps (
140- umfProxyPoolOps (), std::move (provider), nullptr );
141- if (ret != UMF_RESULT_SUCCESS)
142- throw umf::umf2urResult (ret);
143- return std::move (poolHandle);
144- } else {
145- auto umfParams = getUmfParamsHandle (*poolParams);
158+ return std::move (provider);
159+ }
146160
147- auto [ret, poolHandle] =
148- umf::poolMakeUniqueFromOps (umfDisjointPoolOps (), std::move (provider),
149- static_cast <void *>(umfParams.get ()));
150- if (ret != UMF_RESULT_SUCCESS)
151- throw umf::umf2urResult (ret);
152- return std::move (poolHandle);
153- }
161+ static umf::pool_unique_handle_t
162+ makeDisjointPool (umf::provider_unique_handle_t &&provider,
163+ usm::umf_disjoint_pool_config_t &poolParams) {
164+ auto umfParams = getUmfParamsHandle (poolParams);
165+ auto [ret, poolHandle] =
166+ umf::poolMakeUniqueFromOps (umfDisjointPoolOps (), std::move (provider),
167+ static_cast <void *>(umfParams.get ()));
168+ if (ret != UMF_RESULT_SUCCESS)
169+ throw umf::umf2urResult (ret);
170+ return std::move (poolHandle);
171+ }
172+
173+ static umf::pool_unique_handle_t
174+ makeProxyPool (umf::provider_unique_handle_t &&provider) {
175+ auto [ret, poolHandle] = umf::poolMakeUniqueFromOps (
176+ umfProxyPoolOps (), std::move (provider), nullptr );
177+ if (ret != UMF_RESULT_SUCCESS)
178+ throw umf::umf2urResult (ret);
179+
180+ return std::move (poolHandle);
154181}
155182
156183ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t hContext,
157184 ur_usm_pool_desc_t *pPoolDesc)
158185 : hContext(hContext) {
159186 // TODO: handle UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK from pPoolDesc
160187 auto disjointPoolConfigs = initializeDisjointPoolConfig ();
161- if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t >(pPoolDesc)) {
162- for (auto &config : disjointPoolConfigs.Configs ) {
163- config.MaxPoolableSize = limits->maxPoolableSize ;
164- config.SlabMinSize = limits->minDriverAllocSize ;
188+
189+ if (disjointPoolConfigs.has_value ()) {
190+ if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t >(pPoolDesc)) {
191+ for (auto &config : disjointPoolConfigs.value ().Configs ) {
192+ config.MaxPoolableSize = limits->maxPoolableSize ;
193+ config.SlabMinSize = limits->minDriverAllocSize ;
194+ }
165195 }
196+ } else {
197+ // If pooling is disabled, do nothing.
198+ logger::info (" USM pooling is disabled. Skiping pool limits adjustment." );
166199 }
167200
168201 auto [result, descriptors] = usm::pool_descriptor::create (this , hContext);
@@ -171,12 +204,13 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
171204 }
172205
173206 for (auto &desc : descriptors) {
174- if (disjointPoolConfigs.EnableBuffers ) {
207+ if (disjointPoolConfigs.has_value () ) {
175208 auto &poolConfig =
176- disjointPoolConfigs.Configs [descToDisjoinPoolMemType (desc)];
177- poolManager.addPool (desc, makePool (&poolConfig, desc));
209+ disjointPoolConfigs.value ().Configs [descToDisjoinPoolMemType (desc)];
210+ poolManager.addPool (desc,
211+ makeDisjointPool (makeProvider (desc), poolConfig));
178212 } else {
179- poolManager.addPool (desc, makePool ( nullptr , desc));
213+ poolManager.addPool (desc, makeProxyPool ( makeProvider ( desc) ));
180214 }
181215 }
182216}
0 commit comments