@@ -66,6 +66,39 @@ std::string get_device_info_str(const device &Device)
66
66
return ss.str ();
67
67
}
68
68
69
+ /* !
70
+ * @brief Canonicalizes a device identifier bit flag to have a valid (i.e., not
71
+ * UNKNOWN) backend and device type bits.
72
+ *
73
+ * The device id is bit flag that indicates the backend and device type, both
74
+ * of which are optional, that are to be queried. The function makes sure if a
75
+ * device identifier only provides a device type value the backend is set to
76
+ * DPCTL_ALL_BACKENDS. Similarly, if only backend is provided the device type
77
+ * is set to DPCTL_ALL.
78
+ *
79
+ * @param device_id A bit flag storing a backend and a device type value.
80
+ * @return Canonicalized bit flag that makes sure neither backend nor device
81
+ * type is UNKNOWN (0). For cases where the input device id does not provide
82
+ * either one of the values, we set the value to ALL.
83
+ */
84
+ int to_canonical_device_id (int device_id)
85
+ { // If the identifier is 0 (UNKNOWN_DEVICE) return 0.
86
+ if (!device_id)
87
+ return 0 ;
88
+
89
+ // Check if the device identifier has a backend specified. If not, then
90
+ // toggle all backend specifier bits, i.e. set the backend to
91
+ // DPCTL_ALL_BACKENDS.
92
+ if (!(device_id & DPCTL_ALL_BACKENDS))
93
+ device_id |= DPCTL_ALL_BACKENDS;
94
+
95
+ // Check if a device type was specified. If not, set device type to ALL.
96
+ if (!(device_id & ~DPCTL_ALL_BACKENDS))
97
+ device_id |= DPCTL_ALL;
98
+
99
+ return device_id;
100
+ }
101
+
69
102
struct DeviceCacheBuilder
70
103
{
71
104
using DeviceCache = std::unordered_map<device, context>;
@@ -146,12 +179,18 @@ DPCTLDeviceMgr_GetDevices(int device_identifier)
146
179
{
147
180
std::vector<DPCTLSyclDeviceRef> *Devices = nullptr ;
148
181
182
+ device_identifier = to_canonical_device_id (device_identifier);
183
+
149
184
try {
150
185
Devices = new std::vector<DPCTLSyclDeviceRef>();
151
186
} catch (std::bad_alloc const &ba) {
152
187
delete Devices;
153
188
return nullptr ;
154
189
}
190
+
191
+ if (!device_identifier)
192
+ return wrap (Devices);
193
+
155
194
const auto &root_devices = device::get_devices ();
156
195
default_selector mRanker ;
157
196
@@ -195,6 +234,10 @@ int DPCTLDeviceMgr_GetPositionInDevices(__dpctl_keep DPCTLSyclDeviceRef DRef,
195
234
return not_found;
196
235
}
197
236
237
+ device_identifier = to_canonical_device_id (device_identifier);
238
+ if (!device_identifier)
239
+ return not_found;
240
+
198
241
const auto &root_devices = device::get_devices ();
199
242
default_selector mRanker ;
200
243
int index = not_found;
@@ -224,6 +267,11 @@ size_t DPCTLDeviceMgr_GetNumDevices(int device_identifier)
224
267
{
225
268
size_t nDevices = 0 ;
226
269
auto &cache = DeviceCacheBuilder::getDeviceCache ();
270
+
271
+ device_identifier = to_canonical_device_id (device_identifier);
272
+ if (!device_identifier)
273
+ return 0 ;
274
+
227
275
for (const auto &entry : cache) {
228
276
auto Bty (DPCTL_SyclBackendToDPCTLBackendType (
229
277
entry.first .get_platform ().get_backend ()));
0 commit comments