@@ -43,15 +43,31 @@ ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
4343 }
4444
4545 std::vector<ze_driver_handle_t > ZeDrivers;
46+ std::vector<ze_device_handle_t > ZeDevices;
4647 ZeDrivers.resize (ZeDriverCount);
4748
4849 ZE2UR_CALL (zeDriverGet, (&ZeDriverCount, ZeDrivers.data ()));
4950 for (uint32_t I = 0 ; I < ZeDriverCount; ++I) {
50- auto platform = std::make_unique<ur_platform_handle_t_>(ZeDrivers[I]);
51- UR_CALL (platform->initialize ());
52-
53- // Save a copy in the cache for future uses.
54- platforms.push_back (std::move (platform));
51+ ze_device_properties_t device_properties{};
52+ device_properties.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES;
53+ uint32_t ZeDeviceCount = 0 ;
54+ ZE2UR_CALL (zeDeviceGet, (ZeDrivers[I], &ZeDeviceCount, nullptr ));
55+ ZeDevices.resize (ZeDeviceCount);
56+ ZE2UR_CALL (zeDeviceGet, (ZeDrivers[I], &ZeDeviceCount, ZeDevices.data ()));
57+ // Check if this driver has GPU Devices
58+ for (uint32_t D = 0 ; D < ZeDeviceCount; ++D) {
59+ ZE2UR_CALL (zeDeviceGetProperties, (ZeDevices[D], &device_properties));
60+
61+ if (ZE_DEVICE_TYPE_GPU == device_properties.type ) {
62+ // If this Driver is a GPU, save it as a usable platform.
63+ auto platform = std::make_unique<ur_platform_handle_t_>(ZeDrivers[I]);
64+ UR_CALL (platform->initialize ());
65+
66+ // Save a copy in the cache for future uses.
67+ platforms.push_back (std::move (platform));
68+ break ;
69+ }
70+ }
5571 }
5672 return UR_RESULT_SUCCESS;
5773} catch (...) {
@@ -105,8 +121,16 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
105121 // We must only initialize the driver once, even if urPlatformGet() is
106122 // called multiple times. Declaring the return value as "static" ensures
107123 // it's only called once.
108- GlobalAdapter->ZeResult =
109- ZE_CALL_NOCHECK (zeInit, (ZE_INIT_FLAG_GPU_ONLY));
124+
125+ // Init with all flags set to enable for all driver types to be init in
126+ // the application.
127+ ze_init_flags_t L0InitFlags = ZE_INIT_FLAG_GPU_ONLY;
128+ if (UrL0InitAllDrivers) {
129+ L0InitFlags |= ZE_INIT_FLAG_VPU_ONLY;
130+ }
131+ logger::debug (" \n zeInit with flags value of {}\n " ,
132+ static_cast <int >(L0InitFlags));
133+ GlobalAdapter->ZeResult = ZE_CALL_NOCHECK (zeInit, (L0InitFlags));
110134 }
111135 assert (GlobalAdapter->ZeResult !=
112136 std::nullopt ); // verify that level-zero is initialized
0 commit comments