1111#ifndef USM_POOL_MANAGER_HPP
1212#define USM_POOL_MANAGER_HPP 1
1313
14+ #include < ur_ddi.h>
15+
1416#include " logger/ur_logger.hpp"
1517#include " umf_helpers.hpp"
1618#include " ur_api.h"
2628
2729namespace usm {
2830
31+ namespace detail {
32+ struct ddiTables {
33+ ddiTables () {
34+ auto ret =
35+ urGetDeviceProcAddrTable (UR_API_VERSION_CURRENT, &deviceDdiTable);
36+ if (ret != UR_RESULT_SUCCESS) {
37+ throw ret;
38+ }
39+
40+ ret =
41+ urGetContextProcAddrTable (UR_API_VERSION_CURRENT, &contextDdiTable);
42+ if (ret != UR_RESULT_SUCCESS) {
43+ throw ret;
44+ }
45+ }
46+ ur_device_dditable_t deviceDdiTable;
47+ ur_context_dditable_t contextDdiTable;
48+ };
49+ } // namespace detail
50+
2951// / @brief describes an internal USM pool instance.
3052struct pool_descriptor {
3153 ur_usm_pool_handle_t poolHandle;
@@ -44,9 +66,12 @@ struct pool_descriptor {
4466
4567static inline std::pair<ur_result_t , std::vector<ur_device_handle_t >>
4668urGetSubDevices (ur_device_handle_t hDevice) {
69+ static detail::ddiTables ddi;
70+
4771 uint32_t nComputeUnits;
48- auto ret = urDeviceGetInfo (hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS,
49- sizeof (nComputeUnits), &nComputeUnits, nullptr );
72+ auto ret = ddi.deviceDdiTable .pfnGetInfo (
73+ hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS, sizeof (nComputeUnits),
74+ &nComputeUnits, nullptr );
5075 if (ret != UR_RESULT_SUCCESS) {
5176 return {ret, {}};
5277 }
@@ -64,15 +89,16 @@ urGetSubDevices(ur_device_handle_t hDevice) {
6489
6590 // Get the number of devices that will be created
6691 uint32_t deviceCount;
67- ret = urDevicePartition (hDevice, &properties, 0 , nullptr , &deviceCount);
92+ ret = ddi.deviceDdiTable .pfnPartition (hDevice, &properties, 0 , nullptr ,
93+ &deviceCount);
6894 if (ret != UR_RESULT_SUCCESS) {
6995 return {ret, {}};
7096 }
7197
7298 std::vector<ur_device_handle_t > sub_devices (deviceCount);
73- ret = urDevicePartition (hDevice, &properties,
74- static_cast <uint32_t >(sub_devices.size ()),
75- sub_devices.data (), nullptr );
99+ ret = ddi. deviceDdiTable . pfnPartition (
100+ hDevice, &properties, static_cast <uint32_t >(sub_devices.size ()),
101+ sub_devices.data (), nullptr );
76102 if (ret != UR_RESULT_SUCCESS) {
77103 return {ret, {}};
78104 }
@@ -82,17 +108,20 @@ urGetSubDevices(ur_device_handle_t hDevice) {
82108
83109inline std::pair<ur_result_t , std::vector<ur_device_handle_t >>
84110urGetAllDevicesAndSubDevices (ur_context_handle_t hContext) {
111+ static detail::ddiTables ddi;
112+
85113 size_t deviceCount = 0 ;
86- auto ret = urContextGetInfo (hContext, UR_CONTEXT_INFO_NUM_DEVICES,
87- sizeof (deviceCount), &deviceCount, nullptr );
114+ auto ret = ddi.contextDdiTable .pfnGetInfo (
115+ hContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof (deviceCount),
116+ &deviceCount, nullptr );
88117 if (ret != UR_RESULT_SUCCESS || deviceCount == 0 ) {
89118 return {ret, {}};
90119 }
91120
92121 std::vector<ur_device_handle_t > devices (deviceCount);
93- ret = urContextGetInfo (hContext, UR_CONTEXT_INFO_DEVICES,
94- sizeof ( ur_device_handle_t ) * deviceCount ,
95- devices.data (), nullptr );
122+ ret = ddi. contextDdiTable . pfnGetInfo (
123+ hContext, UR_CONTEXT_INFO_DEVICES ,
124+ sizeof ( ur_device_handle_t ) * deviceCount, devices.data (), nullptr );
96125 if (ret != UR_RESULT_SUCCESS) {
97126 return {ret, {}};
98127 }
@@ -135,6 +164,8 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
135164}
136165
137166inline bool pool_descriptor::operator ==(const pool_descriptor &other) const {
167+ static usm::detail::ddiTables ddi;
168+
138169 const pool_descriptor &lhs = *this ;
139170 const pool_descriptor &rhs = other;
140171 ur_native_handle_t lhsNative = 0 , rhsNative = 0 ;
@@ -145,14 +176,16 @@ inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
145176 // Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1
146177 // TODO: is this L0 specific?
147178 if (lhs.hDevice ) {
148- auto ret = urDeviceGetNativeHandle (lhs.hDevice , &lhsNative);
179+ auto ret =
180+ ddi.deviceDdiTable .pfnGetNativeHandle (lhs.hDevice , &lhsNative);
149181 if (ret != UR_RESULT_SUCCESS) {
150182 throw ret;
151183 }
152184 }
153185
154186 if (rhs.hDevice ) {
155- auto ret = urDeviceGetNativeHandle (rhs.hDevice , &rhsNative);
187+ auto ret =
188+ ddi.deviceDdiTable .pfnGetNativeHandle (rhs.hDevice , &rhsNative);
156189 if (ret != UR_RESULT_SUCCESS) {
157190 throw ret;
158191 }
@@ -264,9 +297,12 @@ namespace std {
264297// / @brief hash specialization for usm::pool_descriptor
265298template <> struct hash <usm::pool_descriptor> {
266299 inline size_t operator ()(const usm::pool_descriptor &desc) const {
300+ static usm::detail::ddiTables ddi;
301+
267302 ur_native_handle_t native = 0 ;
268303 if (desc.hDevice ) {
269- auto ret = urDeviceGetNativeHandle (desc.hDevice , &native);
304+ auto ret =
305+ ddi.deviceDdiTable .pfnGetNativeHandle (desc.hDevice , &native);
270306 if (ret != UR_RESULT_SUCCESS) {
271307 throw ret;
272308 }
0 commit comments