Skip to content

Commit faafe03

Browse files
committed
[SYCL] Fix platform::get_devices
Following the clarification in KhronosGroup/SYCL-Docs#861, the implementation of `platform::get_devices` does not work correctly for the `custom` and `automatic` device types. This commit fixes that: a `custom` query now returns the devices of custom type, and an `automatic` query returns the default device. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent a69ec83 commit faafe03

File tree

11 files changed

+242
-57
lines changed

11 files changed

+242
-57
lines changed

sycl/source/detail/allowlist.cpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -399,22 +399,25 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
399399
Device, UR_DEVICE_INFO_TYPE, sizeof(UrDevType), &UrDevType, nullptr);
400400
// TODO need mechanism to do these casts, there's a bunch of this sort of
401401
// thing
402-
sycl::info::device_type DeviceType = info::device_type::all;
403-
switch (UrDevType) {
404-
default:
405-
case UR_DEVICE_TYPE_ALL:
406-
DeviceType = info::device_type::all;
407-
break;
408-
case UR_DEVICE_TYPE_GPU:
409-
DeviceType = info::device_type::gpu;
410-
break;
411-
case UR_DEVICE_TYPE_CPU:
412-
DeviceType = info::device_type::cpu;
413-
break;
414-
case UR_DEVICE_TYPE_FPGA:
415-
DeviceType = info::device_type::accelerator;
416-
break;
417-
}
402+
sycl::info::device_type DeviceType = [UrDevType]() {
403+
switch (UrDevType) {
404+
default:
405+
case UR_DEVICE_TYPE_ALL:
406+
return info::device_type::all;
407+
case UR_DEVICE_TYPE_GPU:
408+
return info::device_type::gpu;
409+
case UR_DEVICE_TYPE_CPU:
410+
return info::device_type::cpu;
411+
case UR_DEVICE_TYPE_FPGA:
412+
return info::device_type::accelerator;
413+
case UR_DEVICE_TYPE_CUSTOM:
414+
case UR_DEVICE_TYPE_MCA:
415+
case UR_DEVICE_TYPE_VPU:
416+
return info::device_type::custom;
417+
case UR_DEVICE_TYPE_DEFAULT:
418+
return info::device_type::automatic;
419+
}
420+
}();
418421
for (const auto &SyclDeviceType :
419422
getSyclDeviceTypeMap<true /*Enable 'acc'*/>()) {
420423
if (SyclDeviceType.second == DeviceType) {

sycl/source/detail/device_impl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
596596
return device_type::accelerator;
597597
case UR_DEVICE_TYPE_MCA:
598598
case UR_DEVICE_TYPE_VPU:
599+
case UR_DEVICE_TYPE_CUSTOM:
599600
return device_type::custom;
600601
default: {
601602
assert(false);

sycl/source/detail/platform_impl.cpp

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -248,24 +248,25 @@ platform_impl::filterDeviceFilter(std::vector<ur_device_handle_t> &UrDevices,
248248
MAdapter->call<UrApiKind::urDeviceGetInfo>(Device, UR_DEVICE_INFO_TYPE,
249249
sizeof(ur_device_type_t),
250250
&UrDevType, nullptr);
251-
// Assumption here is that there is 1-to-1 mapping between UrDevType and
252-
// Sycl device type for GPU, CPU, and ACC.
253-
info::device_type DeviceType = info::device_type::all;
254-
switch (UrDevType) {
255-
default:
256-
case UR_DEVICE_TYPE_ALL:
257-
DeviceType = info::device_type::all;
258-
break;
259-
case UR_DEVICE_TYPE_GPU:
260-
DeviceType = info::device_type::gpu;
261-
break;
262-
case UR_DEVICE_TYPE_CPU:
263-
DeviceType = info::device_type::cpu;
264-
break;
265-
case UR_DEVICE_TYPE_FPGA:
266-
DeviceType = info::device_type::accelerator;
267-
break;
268-
}
251+
info::device_type DeviceType = [UrDevType]() {
252+
switch (UrDevType) {
253+
default:
254+
case UR_DEVICE_TYPE_ALL:
255+
return info::device_type::all;
256+
case UR_DEVICE_TYPE_GPU:
257+
return info::device_type::gpu;
258+
case UR_DEVICE_TYPE_CPU:
259+
return info::device_type::cpu;
260+
case UR_DEVICE_TYPE_FPGA:
261+
return info::device_type::accelerator;
262+
case UR_DEVICE_TYPE_CUSTOM:
263+
case UR_DEVICE_TYPE_MCA:
264+
case UR_DEVICE_TYPE_VPU:
265+
return info::device_type::custom;
266+
case UR_DEVICE_TYPE_DEFAULT:
267+
return info::device_type::automatic;
268+
}
269+
}();
269270

270271
for (const FilterT &Filter : FilterList->get()) {
271272
backend FilterBackend = Filter.Backend.value_or(backend::all);
@@ -469,34 +470,57 @@ static std::vector<device> amendDeviceAndSubDevices(
469470
std::vector<device>
470471
platform_impl::get_devices(info::device_type DeviceType) const {
471472
std::vector<device> Res;
472-
473-
ods_target_list *OdsTargetList = SYCLConfig<ONEAPI_DEVICE_SELECTOR>::get();
473+
// Host is no longer supported, so it returns an empty vector.
474474
if (DeviceType == info::device_type::host)
475+
return std::vector<device>{};
476+
477+
// For custom devices, UR has additional type enums.
478+
if (DeviceType == info::device_type::custom) {
479+
getDevicesImplHelper(UR_DEVICE_TYPE_CUSTOM, Res);
480+
getDevicesImplHelper(UR_DEVICE_TYPE_MCA, Res);
481+
getDevicesImplHelper(UR_DEVICE_TYPE_VPU, Res);
482+
// Some backends may return the MCA and VPU types as part of custom, so
483+
// remove duplicates.
484+
std::sort(Res.begin(), Res.end(),
485+
[](const sycl::device &D1, const sycl::device &D2) {
486+
std::hash<sycl::device> Hasher;
487+
return Hasher(D1) < Hasher(D2);
488+
});
489+
auto NewEnd = std::unique(Res.begin(), Res.end());
490+
Res.erase(NewEnd, Res.end());
475491
return Res;
492+
}
476493

477-
ur_device_type_t UrDeviceType = UR_DEVICE_TYPE_ALL;
478-
494+
ur_device_type_t UrDeviceType = [DeviceType]() {
479495
switch (DeviceType) {
480-
default:
481496
case info::device_type::all:
482-
UrDeviceType = UR_DEVICE_TYPE_ALL;
483-
break;
497+
return UR_DEVICE_TYPE_ALL;
484498
case info::device_type::gpu:
485-
UrDeviceType = UR_DEVICE_TYPE_GPU;
486-
break;
499+
return UR_DEVICE_TYPE_GPU;
487500
case info::device_type::cpu:
488-
UrDeviceType = UR_DEVICE_TYPE_CPU;
489-
break;
501+
return UR_DEVICE_TYPE_CPU;
490502
case info::device_type::accelerator:
491-
UrDeviceType = UR_DEVICE_TYPE_FPGA;
492-
break;
503+
return UR_DEVICE_TYPE_FPGA;
504+
case info::device_type::automatic:
505+
return UR_DEVICE_TYPE_DEFAULT;
506+
default:
507+
throw sycl::exception(sycl::make_error_code(sycl::errc::invalid),
508+
"Unknown device type.");
493509
}
510+
}();
511+
getDevicesImplHelper(UrDeviceType, Res);
512+
return Res;
513+
}
514+
515+
void
516+
platform_impl::getDevicesImplHelper(ur_device_type_t UrDeviceType,
517+
std::vector<device> &OutVec) const {
518+
size_t InitialOutVecSize = OutVec.size();
494519

495520
uint32_t NumDevices = 0;
496521
MAdapter->call<UrApiKind::urDeviceGet>(MPlatform, UrDeviceType,
497522
0u, // CP info::device_type::all
498523
nullptr, &NumDevices);
499-
const backend Backend = getBackend();
500524

501525
if (NumDevices == 0) {
502526
// If platform doesn't have devices (even without filter)
@@ -514,7 +538,7 @@ platform_impl::get_devices(info::device_type DeviceType) const {
514538
std::lock_guard<std::mutex> Guard(*Adapter->getAdapterMutex());
515539
Adapter->adjustLastDeviceId(MPlatform);
516540
}
517-
return Res;
541+
return;
518542
}
519543

520544
std::vector<ur_device_handle_t> UrDevices(NumDevices);
@@ -532,6 +556,8 @@ platform_impl::get_devices(info::device_type DeviceType) const {
532556
if (SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get())
533557
applyAllowList(UrDevices, MPlatform, *MAdapter);
534558

559+
ods_target_list *OdsTargetList = SYCLConfig<ONEAPI_DEVICE_SELECTOR>::get();
560+
535561
// The first step is to filter out devices that are not compatible with
536562
// ONEAPI_DEVICE_SELECTOR. This is also the mechanism by which top level
537563
// device ids are assigned.
@@ -544,7 +570,7 @@ platform_impl::get_devices(info::device_type DeviceType) const {
544570
// The next step is to inflate the filtered UrDevices into SYCL Device
545571
// objects.
546572
platform_impl &PlatformImpl = getOrMakePlatformImpl(MPlatform, *MAdapter);
547-
std::transform(UrDevices.begin(), UrDevices.end(), std::back_inserter(Res),
573+
std::transform(UrDevices.begin(), UrDevices.end(), std::back_inserter(OutVec),
548574
[&PlatformImpl](const ur_device_handle_t UrDevice) -> device {
549575
return detail::createSyclObjFromImpl<device>(
550576
PlatformImpl.getOrMakeDeviceImpl(UrDevice));
@@ -556,15 +582,15 @@ platform_impl::get_devices(info::device_type DeviceType) const {
556582
MAdapter->call<UrApiKind::urDeviceRelease>(UrDev);
557583

558584
// If we aren't using ONEAPI_DEVICE_SELECTOR, then we are done.
559-
// and if there are no devices so far, there won't be any need to replace them
585+
// and if there are no new devices, there won't be any need to replace them
560586
// with subdevices.
561-
if (!OdsTargetList || Res.size() == 0)
562-
return Res;
587+
if (!OdsTargetList || OutVec.size() == InitialOutVecSize)
588+
return;
563589

564590
// Otherwise, our last step is to revisit the devices, possibly replacing
565591
// them with subdevices (which have been ignored until now)
566-
return amendDeviceAndSubDevices(Backend, Res, OdsTargetList,
567-
PlatformDeviceIndices, PlatformImpl);
592+
OutVec = amendDeviceAndSubDevices(getBackend(), OutVec, OdsTargetList,
593+
PlatformDeviceIndices, PlatformImpl);
568594
}
569595

570596
bool platform_impl::has_extension(const std::string &ExtensionName) const {

sycl/source/detail/platform_impl.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
203203
private:
204204
device_impl *getDeviceImplHelper(ur_device_handle_t UrDevice);
205205

206+
void getDevicesImplHelper(ur_device_type_t UrDeviceType,
207+
std::vector<device> &OutVec) const;
208+
206209
// Helper to get the vector of platforms supported by a given UR adapter
207210
static std::vector<platform> getAdapterPlatforms(adapter_impl &Adapter,
208211
bool Supported = true);
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
//
4+
// Tests platform::get_devices for each device type.
5+
6+
#include <sycl/detail/core.hpp>
7+
#include <sycl/platform.hpp>
8+
#include <unordered_set>
9+
10+
// Translates a SYCL backend enumerator into the name printed in the test log.
// Any enumerator without an explicit case below is reported as "UNKNOWN".
std::string BackendToString(sycl::backend Backend) {
  const char *Name = "UNKNOWN";
  switch (Backend) {
  case sycl::backend::host:
    Name = "host";
    break;
  case sycl::backend::opencl:
    Name = "opencl";
    break;
  case sycl::backend::ext_oneapi_level_zero:
    Name = "ext_oneapi_level_zero";
    break;
  case sycl::backend::ext_oneapi_cuda:
    Name = "ext_oneapi_cuda";
    break;
  case sycl::backend::all:
    Name = "all";
    break;
  case sycl::backend::ext_oneapi_hip:
    Name = "ext_oneapi_hip";
    break;
  case sycl::backend::ext_oneapi_native_cpu:
    Name = "ext_oneapi_native_cpu";
    break;
  case sycl::backend::ext_oneapi_offload:
    Name = "ext_oneapi_offload";
    break;
  default:
    // Keep the "UNKNOWN" placeholder.
    break;
  }
  return Name;
}
32+
33+
// Translates a SYCL device type into the "device_type::<name>" label used in
// the failure messages of this test. Values without an explicit case are
// reported as "UNKNOWN".
std::string DeviceTypeToString(sycl::info::device_type DevType) {
  using sycl::info::device_type;
  switch (DevType) {
  case device_type::all:
    return "device_type::all";
  case device_type::cpu:
    return "device_type::cpu";
  case device_type::gpu:
    return "device_type::gpu";
  case device_type::accelerator:
    return "device_type::accelerator";
  case device_type::custom:
    return "device_type::custom";
  case device_type::automatic:
    return "device_type::automatic";
  case device_type::host:
    return "device_type::host";
  default:
    return "UNKNOWN";
  }
}
53+
54+
// Compares LHS against RHS and reports a mismatch on stdout.
//
// LHS and RHS must be comparable with `!=` and streamable to std::cout.
// TestName is appended to the failure message to identify the check; it is
// taken by const reference so callers that build the name with string
// concatenation do not pay for an extra copy on every call.
//
// Returns 0 when the values compare equal and 1 otherwise, so the result can
// be added directly to a running failure counter.
template <typename T1, typename T2>
int Check(const T1 &LHS, const T2 &RHS, const std::string &TestName) {
  if (LHS != RHS) {
    std::cout << "Failed check " << LHS << " != " << RHS << ": " << TestName
              << std::endl;
    return 1;
  }
  return 0;
}
63+
64+
// Checks the result of P.get_devices(DevType) against AllDevices, the set of
// devices the same platform reported for device_type::all.
//
// DevType must not be device_type::all (asserted).
//
// For device_type::automatic the query must return no devices when AllDevices
// is empty, and otherwise exactly one device that is a member of AllDevices.
//
// For every other type the returned list must
//   * contain no duplicates,
//   * contain only devices whose reported device_type equals DevType, and
//   * contain only devices that are members of AllDevices.
//
// Returns the number of failed checks; each failure is printed by Check.
int CheckDeviceType(const sycl::platform &P, sycl::info::device_type DevType,
                    std::unordered_set<sycl::device> &AllDevices) {
  assert(DevType != sycl::info::device_type::all);
  int Failures = 0;

  std::vector<sycl::device> Devices = P.get_devices(DevType);

  if (DevType == sycl::info::device_type::automatic) {
    // Expected values are spelled as size_t so Check compares like with like
    // instead of mixing signed and unsigned operands.
    if (AllDevices.empty()) {
      Failures += Check(
          Devices.size(), size_t{0},
          "No devices reported for all query, but automatic returns a device.");
    } else {
      Failures += Check(Devices.size(), size_t{1},
                        "Number of devices for device_type::automatic query.");
      if (Devices.size())
        Failures +=
            Check(AllDevices.count(Devices[0]), size_t{1},
                  "Device is in the set of all devices in the platform.");
    }
    return Failures;
  }

  // Count devices with the type.
  size_t DevCount = 0;
  for (const sycl::device &Device : Devices)
    DevCount += (Device.get_info<sycl::info::device::device_type>() == DevType);

  std::unordered_set<sycl::device> UniqueDevices{Devices.begin(),
                                                 Devices.end()};
  // The result of this check used to be discarded, so duplicates were printed
  // but never failed the test. Count it like every other check.
  Failures += Check(Devices.size(), UniqueDevices.size(),
                    "Duplicate devices for " + DeviceTypeToString(DevType));

  Failures +=
      Check(Devices.size(), DevCount,
            "Unexpected number of devices for " + DeviceTypeToString(DevType));

  Failures += Check(
      std::all_of(UniqueDevices.begin(), UniqueDevices.end(),
                  [&](const auto &Dev) { return AllDevices.count(Dev) == 1; }),
      true,
      "Not all devices for " + DeviceTypeToString(DevType) +
          " appear in the list of all devices");

  return Failures;
}
110+
111+
int main() {
112+
int Failures = 0;
113+
for (sycl::platform P : sycl::platform::get_platforms()) {
114+
std::cout << "Checking platform with backend "
115+
<< BackendToString(P.get_backend()) << std::endl;
116+
117+
std::vector<sycl::device> Devices = P.get_devices();
118+
std::unordered_set<sycl::device> UniqueDevices{Devices.begin(),
119+
Devices.end()};
120+
121+
if (Check(Devices.size(), UniqueDevices.size(),
122+
"Duplicate devices for device_type::all")) {
123+
++Failures;
124+
// Don't trust this platform, so we continue.
125+
continue;
126+
}
127+
128+
for (sycl::info::device_type DevType :
129+
{sycl::info::device_type::cpu, sycl::info::device_type::gpu,
130+
sycl::info::device_type::accelerator, sycl::info::device_type::custom,
131+
sycl::info::device_type::automatic, sycl::info::device_type::host})
132+
Failures += CheckDeviceType(P, DevType, UniqueDevices);
133+
}
134+
return Failures;
135+
}

unified-runtime/include/ur_api.h

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

unified-runtime/include/ur_print.hpp

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)