Skip to content

Commit 683921f

Browse files
author
Diptorup Deb
committed
Add a helper function to canonicalize device identifiers.
We were not handling the cases where a device identifier does not provide a backend or a device type. Changed the logic to count all devices of a given type if no backend is specified. Similarly, all devices in a backend are counted if no device type is provided.
1 parent 711f868 commit 683921f

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

dpctl-capi/source/dpctl_sycl_device_manager.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,39 @@ std::string get_device_info_str(const device &Device)
6666
return ss.str();
6767
}
6868

69+
/*!
70+
* @brief Canonicalizes a device identifier bit flag to have a valid (i.e., not
71+
* UNKNOWN) backend and device type bits.
72+
*
73+
* The device id is bit flag that indicates the backend and device type, both
74+
* of which are optional, that are to be queried. The function makes sure if a
75+
* device identifier only provides a device type value the backend is set to
76+
* DPCTL_ALL_BACKENDS. Similarly, if only backend is provided the device type
77+
* is set to DPCTL_ALL.
78+
*
79+
* @param device_id A bit flag storing a backend and a device type value.
80+
* @return Canonicalized bit flag that makes sure neither backend nor device
81+
* type is UNKNOWN (0). For cases where the input device id does not provide
82+
* either one of the values, we set the value to ALL.
83+
*/
84+
int to_canonical_device_id(int device_id)
85+
{ // If the identifier is 0 (UNKNOWN_DEVICE) return 0.
86+
if (!device_id)
87+
return 0;
88+
89+
// Check if the device identifier has a backend specified. If not, then
90+
// toggle all backend specifier bits, i.e. set the backend to
91+
// DPCTL_ALL_BACKENDS.
92+
if (!(device_id & DPCTL_ALL_BACKENDS))
93+
device_id |= DPCTL_ALL_BACKENDS;
94+
95+
// Check if a device type was specified. If not, set device type to ALL.
96+
if (!(device_id & ~DPCTL_ALL_BACKENDS))
97+
device_id |= DPCTL_ALL;
98+
99+
return device_id;
100+
}
101+
69102
struct DeviceCacheBuilder
70103
{
71104
using DeviceCache = std::unordered_map<device, context>;
@@ -146,12 +179,18 @@ DPCTLDeviceMgr_GetDevices(int device_identifier)
146179
{
147180
std::vector<DPCTLSyclDeviceRef> *Devices = nullptr;
148181

182+
device_identifier = to_canonical_device_id(device_identifier);
183+
149184
try {
150185
Devices = new std::vector<DPCTLSyclDeviceRef>();
151186
} catch (std::bad_alloc const &ba) {
152187
delete Devices;
153188
return nullptr;
154189
}
190+
191+
if (!device_identifier)
192+
return wrap(Devices);
193+
155194
const auto &root_devices = device::get_devices();
156195
default_selector mRanker;
157196

@@ -195,6 +234,10 @@ int DPCTLDeviceMgr_GetPositionInDevices(__dpctl_keep DPCTLSyclDeviceRef DRef,
195234
return not_found;
196235
}
197236

237+
device_identifier = to_canonical_device_id(device_identifier);
238+
if (!device_identifier)
239+
return not_found;
240+
198241
const auto &root_devices = device::get_devices();
199242
default_selector mRanker;
200243
int index = not_found;
@@ -224,6 +267,11 @@ size_t DPCTLDeviceMgr_GetNumDevices(int device_identifier)
224267
{
225268
size_t nDevices = 0;
226269
auto &cache = DeviceCacheBuilder::getDeviceCache();
270+
271+
device_identifier = to_canonical_device_id(device_identifier);
272+
if (!device_identifier)
273+
return 0;
274+
227275
for (const auto &entry : cache) {
228276
auto Bty(DPCTL_SyclBackendToDPCTLBackendType(
229277
entry.first.get_platform().get_backend()));

0 commit comments

Comments
 (0)