@@ -40,7 +40,36 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGet(
40
40
// Filter available devices based on input DeviceType.
41
41
std::vector<ur_device_handle_t > MatchedDevices;
42
42
std::shared_lock<ur_shared_mutex> Lock (Platform->URDevicesCacheMutex );
43
- bool isCombinedMode = false ;
43
+ // We need to filter out composite devices when
44
+ // ZE_FLAT_DEVICE_HIERARCHY=COMBINED. We can know if we are in combined
45
+ // mode depending on the return value of zeDeviceGetRootDevice:
46
+ // - If COMPOSITE, L0 returns cards as devices. Since we filter out
47
+ // subdevices early, zeDeviceGetRootDevice must return nullptr, because we
48
+ // only query for root-devices and they don't have any device higher up in
49
+ // the hierarchy.
50
+ // - If FLAT, according to L0 spec, zeDeviceGetRootDevice always returns
51
+ // nullptr in this mode.
52
+ // - If COMBINED, L0 returns tiles as devices, and zeDeviceGetRootdevice
53
+ // returns the card containing a given tile.
54
+ bool isCombinedMode =
55
+ std::any_of (Platform->URDevicesCache .begin (),
56
+ Platform->URDevicesCache .end (), [](const auto &D) {
57
+ if (D->isSubDevice ())
58
+ return false ;
59
+ ze_device_handle_t RootDev = nullptr ;
60
+ // Query Root Device for root-devices.
61
+ // We cannot use ZE2UR_CALL because under some circumstances
62
+ // this call may return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE,
63
+ // and ZE2UR_CALL will abort because it's not
64
+ // UR_RESULT_SUCCESS. Instead, we use ZE_CALL_NOCHECK and we
65
+ // check manually that the result is either
66
+ // ZE_RESULT_SUCCESS or ZE_RESULT_ERROR_UNSUPPORTED_FEATURE.
67
+ auto errc = ZE_CALL_NOCHECK (zeDeviceGetRootDevice,
68
+ (D->ZeDevice , &RootDev));
69
+ assert (errc == ZE_RESULT_SUCCESS ||
70
+ errc == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE);
71
+ return RootDev != nullptr ;
72
+ });
44
73
for (auto &D : Platform->URDevicesCache ) {
45
74
// Only ever return root-devices from urDeviceGet, but the
46
75
// devices cache also keeps sub-devices.
@@ -71,56 +100,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGet(
71
100
break ;
72
101
}
73
102
74
- // We need to filter out composite devices depending on
75
- // ZE_FLAT_DEVICE_HIERARCHY value:
76
- // - If COMPOSITE, L0 returns cards as devices. Thus, zeGetRootDevice must
77
- // return nullptr, because they don't have any device higher up in the
78
- // hierarchy.
79
- // - If FLAT, according to L0 spec, zeGetRootDevice always returns
80
- // nullptr in this mode.
81
- // - If COMBINED, L0 returns tiles as devices, and zeGetRootdevice returns
82
- // the card containing a given tile.
83
- //
84
- // NOTE: We cannot directly filter out the composite devices here because we
85
- // might have the composite device appear earlier than we know we are in
86
- // combined mode, so we simply try and infer here if we are in combined
87
- // mode, and then, if so, remove composite devices from MatchedDevices.
88
- if (!isCombinedMode) {
89
- ze_device_handle_t RootDev = nullptr ;
90
- // Query Root Device
91
- // We cannot use ZE2UR_CALL because under some circumstances this call may
92
- // return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, and ZE2UR_CALL will abort
93
- // because it's not UR_RESULT_SUCCESS. Instead, we use ZE_CALL_NOCHECK and
94
- // we check manually that the result is either ZE_RESULT_SUCCESS or
95
- // ZE_RESULT_ERROR_UNSUPPORTED_FEATURE.
96
- auto errc =
97
- ZE_CALL_NOCHECK (zeDeviceGetRootDevice, (D->ZeDevice , &RootDev));
98
- if (errc != ZE_RESULT_SUCCESS &&
99
- errc != ZE_RESULT_ERROR_UNSUPPORTED_FEATURE)
100
- return ze2urResult (errc);
101
- // For COMPOSITE and FLAT modes, RootDev will always be nullptr. Thus a
102
- // single device returning RootDev != nullptr means we are in COMBINED
103
- // mode.
104
- isCombinedMode = (RootDev != nullptr );
105
- }
106
-
107
103
if (Matched) {
108
- MatchedDevices.push_back (D.get ());
109
- }
110
- }
111
-
112
- if (isCombinedMode) {
113
- // Effectively filter out composite devices.
114
- std::vector<std::vector<ur_device_handle_t >::iterator> toDelete;
115
- for (auto it = MatchedDevices.begin (); it != MatchedDevices.end (); ++it) {
116
- const auto &D = *it;
117
- bool isComposite = (D->ZeDeviceProperties ->flags &
118
- ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE) == 0 ;
119
- if (isComposite)
120
- toDelete.push_back (it);
121
- }
122
- for (const auto D : toDelete) {
123
- MatchedDevices.erase (D);
104
+ bool isComposite =
105
+ isCombinedMode && (D->ZeDeviceProperties ->flags &
106
+ ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE) == 0 ;
107
+ if (!isComposite)
108
+ MatchedDevices.push_back (D.get ());
124
109
}
125
110
}
126
111
@@ -888,7 +873,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(
888
873
(DevHandle, &SubDeviceCount, SubDevs.data ()));
889
874
890
875
size_t SubDeviceCount_s{SubDeviceCount};
891
- auto ResSize = std::min (SubDeviceCount_s, propSize);
876
+ auto ResSize =
877
+ std::min (SubDeviceCount_s, propSize / sizeof (ur_device_handle_t ));
892
878
std::vector<ur_device_handle_t > Res;
893
879
for (const auto &d : SubDevs) {
894
880
// We can only reach this code if ZE_FLAT_DEVICE_HIERARCHY != FLAT,
@@ -920,9 +906,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(
920
906
ur_device_handle_t UrRootDev = nullptr ;
921
907
ze_device_handle_t DevHandle = Device->ZeDevice ;
922
908
ze_device_handle_t RootDev;
923
- // Query Root Device
924
- ZE2UR_CALL (zeDeviceGetRootDevice, (DevHandle, &RootDev));
909
+ // Query Root Device.
910
+ auto errc = ZE_CALL_NOCHECK (zeDeviceGetRootDevice, (DevHandle, &RootDev));
925
911
UrRootDev = Device->Platform ->getDeviceFromNativeHandle (RootDev);
912
+ if (errc != ZE_RESULT_SUCCESS &&
913
+ errc != ZE_RESULT_ERROR_UNSUPPORTED_FEATURE)
914
+ return ze2urResult (errc);
926
915
return ReturnValue (UrRootDev);
927
916
}
928
917
0 commit comments