@@ -248,24 +248,25 @@ platform_impl::filterDeviceFilter(std::vector<ur_device_handle_t> &UrDevices,
248248 MAdapter->call <UrApiKind::urDeviceGetInfo>(Device, UR_DEVICE_INFO_TYPE,
249249 sizeof (ur_device_type_t ),
250250 &UrDevType, nullptr );
251- // Assumption here is that there is 1-to-1 mapping between UrDevType and
252- // Sycl device type for GPU, CPU, and ACC.
253- info::device_type DeviceType = info::device_type::all;
254- switch (UrDevType) {
255- default :
256- case UR_DEVICE_TYPE_ALL:
257- DeviceType = info::device_type::all;
258- break ;
259- case UR_DEVICE_TYPE_GPU:
260- DeviceType = info::device_type::gpu;
261- break ;
262- case UR_DEVICE_TYPE_CPU:
263- DeviceType = info::device_type::cpu;
264- break ;
265- case UR_DEVICE_TYPE_FPGA:
266- DeviceType = info::device_type::accelerator;
267- break ;
268- }
251+ info::device_type DeviceType = [UrDevType]() {
252+ switch (UrDevType) {
253+ default :
254+ case UR_DEVICE_TYPE_ALL:
255+ return info::device_type::all;
256+ case UR_DEVICE_TYPE_GPU:
257+ return info::device_type::gpu;
258+ case UR_DEVICE_TYPE_CPU:
259+ return info::device_type::cpu;
260+ case UR_DEVICE_TYPE_FPGA:
261+ return info::device_type::accelerator;
262+ case UR_DEVICE_TYPE_CUSTOM:
263+ case UR_DEVICE_TYPE_MCA:
264+ case UR_DEVICE_TYPE_VPU:
265+ return info::device_type::custom;
266+ case UR_DEVICE_TYPE_DEFAULT:
267+ return info::device_type::automatic;
268+ }
269+ }();
269270
270271 for (const FilterT &Filter : FilterList->get ()) {
271272 backend FilterBackend = Filter.Backend .value_or (backend::all);
@@ -469,34 +470,57 @@ static std::vector<device> amendDeviceAndSubDevices(
469470std::vector<device>
470471platform_impl::get_devices (info::device_type DeviceType) const {
471472 std::vector<device> Res;
472-
473- ods_target_list *OdsTargetList = SYCLConfig<ONEAPI_DEVICE_SELECTOR>::get ();
473+ // Host is no longer supported, so it returns an empty vector.
474474 if (DeviceType == info::device_type::host)
475+ return std::vector<device>{};
476+
477+ // For custom devices, UR has additional type enums.
478+ if (DeviceType == info::device_type::custom) {
479+ getDevicesImplHelper (UR_DEVICE_TYPE_CUSTOM, Res);
480+ getDevicesImplHelper (UR_DEVICE_TYPE_MCA, Res);
481+ getDevicesImplHelper (UR_DEVICE_TYPE_VPU, Res);
482+ // Some backends may return the MCA and VPU types as part of custom, so
483+ // remove duplicates.
484+ std::sort (Res.begin (), Res.end (),
485+ [](const sycl::device &D1, const sycl::device &D2) {
486+ std::hash<sycl::device> Hasher;
487+ return Hasher (D1) < Hasher (D2);
488+ });
489+ auto NewEnd = std::unique (Res.begin (), Res.end ());
490+ Res.erase (NewEnd, Res.end ());
475491 return Res;
492+ }
476493
477- ur_device_type_t UrDeviceType = UR_DEVICE_TYPE_ALL;
478-
494+ ur_device_type_t UrDeviceType = [DeviceType]() {
479495 switch (DeviceType) {
480- default :
481496 case info::device_type::all:
482- UrDeviceType = UR_DEVICE_TYPE_ALL;
483- break ;
497+ return UR_DEVICE_TYPE_ALL;
484498 case info::device_type::gpu:
485- UrDeviceType = UR_DEVICE_TYPE_GPU;
486- break ;
499+ return UR_DEVICE_TYPE_GPU;
487500 case info::device_type::cpu:
488- UrDeviceType = UR_DEVICE_TYPE_CPU;
489- break ;
501+ return UR_DEVICE_TYPE_CPU;
490502 case info::device_type::accelerator:
491- UrDeviceType = UR_DEVICE_TYPE_FPGA;
492- break ;
503+ return UR_DEVICE_TYPE_FPGA;
504+ case info::device_type::automatic:
505+ return UR_DEVICE_TYPE_DEFAULT;
506+ default :
507+ throw sycl::exception (sycl::make_error_code (sycl::errc::invalid),
508+ " Unknown device type." );
493509 }
510+ }();
511+ getDevicesImplHelper (UrDeviceType, Res);
512+ return Res;
513+ }
514+
515+ void
516+ platform_impl::getDevicesImplHelper (ur_device_type_t UrDeviceType,
517+ std::vector<device> &OutVec) const {
518+ size_t InitialOutVecSize = OutVec.size ();
494519
495520 uint32_t NumDevices = 0 ;
496521 MAdapter->call <UrApiKind::urDeviceGet>(MPlatform, UrDeviceType,
497522 0u , // CP info::device_type::all
498523 nullptr , &NumDevices);
499- const backend Backend = getBackend ();
500524
501525 if (NumDevices == 0 ) {
502526 // If platform doesn't have devices (even without filter)
@@ -514,7 +538,7 @@ platform_impl::get_devices(info::device_type DeviceType) const {
514538 std::lock_guard<std::mutex> Guard (*Adapter->getAdapterMutex ());
515539 Adapter->adjustLastDeviceId (MPlatform);
516540 }
517- return Res ;
541+ return ;
518542 }
519543
520544 std::vector<ur_device_handle_t > UrDevices (NumDevices);
@@ -532,6 +556,8 @@ platform_impl::get_devices(info::device_type DeviceType) const {
532556 if (SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get ())
533557 applyAllowList (UrDevices, MPlatform, *MAdapter);
534558
559+ ods_target_list *OdsTargetList = SYCLConfig<ONEAPI_DEVICE_SELECTOR>::get ();
560+
535561 // The first step is to filter out devices that are not compatible with
536562 // ONEAPI_DEVICE_SELECTOR. This is also the mechanism by which top level
537563 // device ids are assigned.
@@ -544,7 +570,7 @@ platform_impl::get_devices(info::device_type DeviceType) const {
544570 // The next step is to inflate the filtered UrDevices into SYCL Device
545571 // objects.
546572 platform_impl &PlatformImpl = getOrMakePlatformImpl (MPlatform, *MAdapter);
547- std::transform (UrDevices.begin (), UrDevices.end (), std::back_inserter (Res ),
573+ std::transform (UrDevices.begin (), UrDevices.end (), std::back_inserter (OutVec ),
548574 [&PlatformImpl](const ur_device_handle_t UrDevice) -> device {
549575 return detail::createSyclObjFromImpl<device>(
550576 PlatformImpl.getOrMakeDeviceImpl (UrDevice));
@@ -556,15 +582,15 @@ platform_impl::get_devices(info::device_type DeviceType) const {
556582 MAdapter->call <UrApiKind::urDeviceRelease>(UrDev);
557583
558584 // If we aren't using ONEAPI_DEVICE_SELECTOR, then we are done.
559- // and if there are no devices so far , there won't be any need to replace them
585+ // and if there are no new devices , there won't be any need to replace them
560586 // with subdevices.
561- if (!OdsTargetList || Res .size () == 0 )
562- return Res ;
587+ if (!OdsTargetList || OutVec .size () == InitialOutVecSize )
588+ return ;
563589
564590 // Otherwise, our last step is to revisit the devices, possibly replacing
565591 // them with subdevices (which have been ignored until now)
566- return amendDeviceAndSubDevices (Backend, Res , OdsTargetList,
567- PlatformDeviceIndices, PlatformImpl);
592+ OutVec = amendDeviceAndSubDevices (getBackend (), OutVec , OdsTargetList,
593+ PlatformDeviceIndices, PlatformImpl);
568594}
569595
570596bool platform_impl::has_extension (const std::string &ExtensionName) const {
0 commit comments