@@ -61,104 +61,12 @@ struct pool_descriptor {
6161 bool operator ==(const pool_descriptor &other) const ;
6262 friend std::ostream &operator <<(std::ostream &os,
6363 const pool_descriptor &desc);
64- static std::pair<ur_result_t , std::vector<pool_descriptor>>
65- create (ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext);
64+ static std::vector<pool_descriptor>
65+ createFromDevices (ur_usm_pool_handle_t poolHandle,
66+ ur_context_handle_t hContext,
67+ const std::vector<ur_device_handle_t > &devices);
6668};
6769
68- static inline std::pair<ur_result_t , std::vector<ur_device_handle_t >>
69- urGetSubDevices (ur_device_handle_t hDevice) {
70- static detail::ddiTables ddi;
71-
72- uint32_t nComputeUnits;
73- auto ret = ddi.deviceDdiTable .pfnGetInfo (
74- hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS, sizeof (nComputeUnits),
75- &nComputeUnits, nullptr );
76- if (ret != UR_RESULT_SUCCESS) {
77- return {ret, {}};
78- }
79-
80- ur_device_partition_property_t prop;
81- prop.type = UR_DEVICE_PARTITION_BY_CSLICE;
82- prop.value .affinity_domain = 0 ;
83-
84- ur_device_partition_properties_t properties{
85- UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES,
86- nullptr ,
87- &prop,
88- 1 ,
89- };
90-
91- // Get the number of devices that will be created
92- uint32_t deviceCount;
93- ret = ddi.deviceDdiTable .pfnPartition (hDevice, &properties, 0 , nullptr ,
94- &deviceCount);
95- if (ret != UR_RESULT_SUCCESS) {
96- return {ret, {}};
97- }
98-
99- std::vector<ur_device_handle_t > sub_devices (deviceCount);
100- ret = ddi.deviceDdiTable .pfnPartition (
101- hDevice, &properties, static_cast <uint32_t >(sub_devices.size ()),
102- sub_devices.data (), nullptr );
103- if (ret != UR_RESULT_SUCCESS) {
104- return {ret, {}};
105- }
106-
107- return {UR_RESULT_SUCCESS, sub_devices};
108- }
109-
110- inline std::pair<ur_result_t , std::vector<ur_device_handle_t >>
111- urGetAllDevicesAndSubDevices (ur_context_handle_t hContext) {
112- static detail::ddiTables ddi;
113-
114- size_t deviceCount = 0 ;
115- auto ret = ddi.contextDdiTable .pfnGetInfo (
116- hContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof (deviceCount), &deviceCount,
117- nullptr );
118- if (ret != UR_RESULT_SUCCESS || deviceCount == 0 ) {
119- return {ret, {}};
120- }
121-
122- std::vector<ur_device_handle_t > devices (deviceCount);
123- ret = ddi.contextDdiTable .pfnGetInfo (hContext, UR_CONTEXT_INFO_DEVICES,
124- sizeof (ur_device_handle_t ) * deviceCount,
125- devices.data (), nullptr );
126- if (ret != UR_RESULT_SUCCESS) {
127- return {ret, {}};
128- }
129-
130- std::vector<ur_device_handle_t > devicesAndSubDevices;
131- std::function<ur_result_t (ur_device_handle_t )> addPoolsForDevicesRec =
132- [&](ur_device_handle_t hDevice) {
133- devicesAndSubDevices.push_back (hDevice);
134- auto [ret, subDevices] = urGetSubDevices (hDevice);
135- if (ret != UR_RESULT_SUCCESS) {
136- return ret;
137- }
138- for (auto &subDevice : subDevices) {
139- ret = addPoolsForDevicesRec (subDevice);
140- if (ret != UR_RESULT_SUCCESS) {
141- return ret;
142- }
143- }
144- return UR_RESULT_SUCCESS;
145- };
146-
147- for (size_t i = 0 ; i < deviceCount; i++) {
148- ret = addPoolsForDevicesRec (devices[i]);
149- if (ret != UR_RESULT_SUCCESS) {
150- if (ret == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
151- // Return main devices when sub-devices are unsupported.
152- return {ret, std::move (devices)};
153- }
154-
155- return {ret, {}};
156- }
157- }
158-
159- return {UR_RESULT_SUCCESS, devicesAndSubDevices};
160- }
161-
16270static inline bool
16371isSharedAllocationReadOnlyOnDevice (const pool_descriptor &desc) {
16472 return desc.type == UR_USM_TYPE_SHARED && desc.deviceReadOnly ;
@@ -205,14 +113,9 @@ inline std::ostream &operator<<(std::ostream &os, const pool_descriptor &desc) {
205113 return os;
206114}
207115
208- inline std::pair<ur_result_t , std::vector<pool_descriptor>>
209- pool_descriptor::create (ur_usm_pool_handle_t poolHandle,
210- ur_context_handle_t hContext) {
211- auto [ret, devices] = urGetAllDevicesAndSubDevices (hContext);
212- if (ret != UR_RESULT_SUCCESS) {
213- return {ret, {}};
214- }
215-
116+ inline std::vector<pool_descriptor> pool_descriptor::createFromDevices (
117+ ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext,
118+ const std::vector<ur_device_handle_t > &devices) {
216119 std::vector<pool_descriptor> descriptors;
217120 pool_descriptor &desc = descriptors.emplace_back ();
218121 desc.poolHandle = poolHandle;
@@ -245,7 +148,7 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
245148 }
246149 }
247150
248- return {ret, descriptors} ;
151+ return descriptors;
249152}
250153
251154template <typename D> struct pool_manager {
0 commit comments