Skip to content

Commit 3a49cca

Browse files
[UMF] Update Disjoint Pool config params
1 parent de70a26 commit 3a49cca

File tree

6 files changed

+242
-103
lines changed

6 files changed

+242
-103
lines changed

source/adapters/cuda/usm.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,20 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
387387
const ur_usm_pool_limits_desc_t *Limits =
388388
reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
389389
for (auto &config : DisjointPoolConfigs.Configs) {
390-
config.MaxPoolableSize = Limits->maxPoolableSize;
391-
config.SlabMinSize = Limits->minDriverAllocSize;
390+
umf_result_t umf_ret = umfDisjointPoolParamsSetMaxPoolableSize(
391+
config, Limits->maxPoolableSize);
392+
if (umf_ret != UMF_RESULT_SUCCESS) {
393+
logger::error("urUSMPoolHandle: setting maxPoolableSize in "
394+
"DisjointPool params failed");
395+
throw umf::umf2urResult(umf_ret);
396+
}
397+
umf_ret = umfDisjointPoolParamsSetSlabMinSize(
398+
config, Limits->minDriverAllocSize);
399+
if (umf_ret != UMF_RESULT_SUCCESS) {
400+
logger::error("urUSMPoolHandle: setting slabMinSize in DisjointPool "
401+
"params failed");
402+
throw umf::umf2urResult(umf_ret);
403+
}
392404
}
393405
break;
394406
}

source/adapters/hip/usm.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,20 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
333333
if (PoolDesc) {
334334
if (auto *Limits = find_stype_node<ur_usm_pool_limits_desc_t>(PoolDesc)) {
335335
for (auto &config : DisjointPoolConfigs.Configs) {
336-
config.MaxPoolableSize = Limits->maxPoolableSize;
337-
config.SlabMinSize = Limits->minDriverAllocSize;
336+
umf_result_t umf_ret = umfDisjointPoolParamsSetMaxPoolableSize(
337+
config, Limits->maxPoolableSize);
338+
if (umf_ret != UMF_RESULT_SUCCESS) {
339+
logger::error("urUSMPoolHandle: setting maxPoolableSize in "
340+
"DisjointPool params failed");
341+
throw umf::umf2urResult(umf_ret);
342+
}
343+
umf_ret = umfDisjointPoolParamsSetSlabMinSize(
344+
config, Limits->minDriverAllocSize);
345+
if (umf_ret != UMF_RESULT_SUCCESS) {
346+
logger::error("urUSMPoolHandle: setting slabMinSize in DisjointPool "
347+
"params failed");
348+
throw umf::umf2urResult(umf_ret);
349+
}
338350
}
339351
} else {
340352
throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);

source/adapters/level_zero/usm.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,8 +1034,18 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
10341034
const ur_usm_pool_limits_desc_t *Limits =
10351035
reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
10361036
for (auto &config : DisjointPoolConfigs.Configs) {
1037-
config.MaxPoolableSize = Limits->maxPoolableSize;
1038-
config.SlabMinSize = Limits->minDriverAllocSize;
1037+
if (umfDisjointPoolParamsSetMaxPoolableSize(
1038+
config, Limits->maxPoolableSize) != UMF_RESULT_SUCCESS) {
1039+
logger::error("urUSMPoolCreate: setting maxPoolableSize in "
1040+
"DisjointPool params failed");
1041+
throw UsmAllocationException(UR_RESULT_ERROR_UNKNOWN);
1042+
}
1043+
if (umfDisjointPoolParamsSetSlabMinSize(
1044+
config, Limits->minDriverAllocSize) != UMF_RESULT_SUCCESS) {
1045+
logger::error("urUSMPoolCreate: setting slabMinSize in DisjointPool "
1046+
"params failed");
1047+
throw UsmAllocationException(UR_RESULT_ERROR_UNKNOWN);
1048+
}
10391049
}
10401050
break;
10411051
}

source/adapters/level_zero/v2/usm.cpp

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,27 +82,50 @@ descToDisjoinPoolMemType(const usm::pool_descriptor &desc) {
8282
}
8383

8484
static umf::pool_unique_handle_t
85-
makePool(umf_disjoint_pool_params_t *poolParams,
85+
makePool(umf_disjoint_pool_params_handle_t *poolParams,
8686
usm::pool_descriptor poolDescriptor) {
87-
level_zero_memory_provider_params_t params = {};
88-
params.level_zero_context_handle = poolDescriptor.hContext->getZeHandle();
89-
params.level_zero_device_handle =
87+
umf_level_zero_memory_provider_params_handle_t params = NULL;
88+
umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate(&params);
89+
if (umf_ret != UMF_RESULT_SUCCESS) {
90+
throw umf::umf2urResult(umf_ret);
91+
}
92+
93+
umf_ret = umfLevelZeroMemoryProviderParamsSetContext(
94+
params, poolDescriptor.hContext->getZeHandle());
95+
if (umf_ret != UMF_RESULT_SUCCESS) {
96+
throw umf::umf2urResult(umf_ret);
97+
}
98+
99+
ze_device_handle_t level_zero_device_handle =
90100
poolDescriptor.hDevice ? poolDescriptor.hDevice->ZeDevice : nullptr;
91-
params.memory_type = urToUmfMemoryType(poolDescriptor.type);
101+
umf_ret = umfLevelZeroMemoryProviderParamsSetDevice(params,
102+
level_zero_device_handle);
103+
if (umf_ret != UMF_RESULT_SUCCESS) {
104+
throw umf::umf2urResult(umf_ret);
105+
}
106+
107+
umf_ret = umfLevelZeroMemoryProviderParamsSetMemoryType(
108+
params, urToUmfMemoryType(poolDescriptor.type));
109+
if (umf_ret != UMF_RESULT_SUCCESS) {
110+
throw umf::umf2urResult(umf_ret);
111+
}
92112

93113
std::vector<ze_device_handle_t> residentZeHandles;
94114

95115
if (poolDescriptor.type == UR_USM_TYPE_DEVICE) {
96-
assert(params.level_zero_device_handle);
116+
assert(level_zero_device_handle);
97117
auto residentHandles =
98118
poolDescriptor.hContext->getP2PDevices(poolDescriptor.hDevice);
99-
residentZeHandles.push_back(params.level_zero_device_handle);
119+
residentZeHandles.push_back(level_zero_device_handle);
100120
for (auto &device : residentHandles) {
101121
residentZeHandles.push_back(device->ZeDevice);
102122
}
103123

104-
params.resident_device_handles = residentZeHandles.data();
105-
params.resident_device_count = residentZeHandles.size();
124+
umf_ret = umfLevelZeroMemoryProviderParamsSetResidentDevices(
125+
params, residentZeHandles.data(), residentZeHandles.size());
126+
if (umf_ret != UMF_RESULT_SUCCESS) {
127+
throw umf::umf2urResult(umf_ret);
128+
}
106129
}
107130

108131
auto [ret, provider] =
@@ -134,8 +157,20 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
134157
auto disjointPoolConfigs = initializeDisjointPoolConfig();
135158
if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t>(pPoolDesc)) {
136159
for (auto &config : disjointPoolConfigs.Configs) {
137-
config.MaxPoolableSize = limits->maxPoolableSize;
138-
config.SlabMinSize = limits->minDriverAllocSize;
160+
umf_result_t umf_ret = umfDisjointPoolParamsSetMaxPoolableSize(
161+
config, limits->maxPoolableSize);
162+
if (umf_ret != UMF_RESULT_SUCCESS) {
163+
logger::error("urUSMPoolHandle: setting maxPoolableSize in "
164+
"DisjointPool params failed");
165+
throw umf::umf2urResult(umf_ret);
166+
}
167+
umf_ret = umfDisjointPoolParamsSetSlabMinSize(config,
168+
limits->minDriverAllocSize);
169+
if (umf_ret != UMF_RESULT_SUCCESS) {
170+
logger::error("urUSMPoolHandle: setting slabMinSize in DisjointPool "
171+
"params failed");
172+
throw umf::umf2urResult(umf_ret);
173+
}
139174
}
140175
}
141176

0 commit comments

Comments
 (0)