Skip to content

Commit 2881074

Browse files
authored
Support for ZEL_DRIVERS_ORDER to order based off user input (#371)
* Fix unit tests for static builds * Remove unit test build/run on windows dynamic loader * Add changes to the internal headers * Move helper functions to new header and update tests Signed-off-by: Neil R. Spruit <[email protected]>
1 parent f874a09 commit 2881074

File tree

8 files changed

+2074
-6
lines changed

8 files changed

+2074
-6
lines changed

scripts/templates/ze_loader_internal.h.mako

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ namespace loader
4747
ZEL_DRIVER_TYPE_INTEGRATED_GPU = 2, ///< The driver has Integrated GPUs only
4848
ZEL_DRIVER_TYPE_MIXED = 3, ///< The driver has Heterogenous driver types not limited to GPU or NPU.
4949
ZEL_DRIVER_TYPE_OTHER = 4, ///< The driver has No GPU Devices and has other device types only
50+
ZEL_DRIVER_TYPE_NPU = 5, ///< The driver has NPU devices only
5051
ZEL_DRIVER_TYPE_FORCE_UINT32 = 0x7fffffff
5152

5253
} zel_driver_type_t;
@@ -114,6 +115,7 @@ namespace loader
114115
ze_result_t init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly);
115116
void add_loader_version();
116117
bool driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly);
118+
void driverOrdering(driver_vector_t *drivers);
117119
~context_t();
118120
bool intercept_enabled = false;
119121
bool debugTraceEnabled = false;

source/loader/ze_loader.cpp

Lines changed: 180 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* SPDX-License-Identifier: MIT
66
*
77
*/
8-
#include "ze_loader_internal.h"
8+
#include "ze_loader_utils.h"
99

1010
#include "driver_discovery.h"
1111
#include <iostream>
@@ -72,6 +72,179 @@ namespace loader
7272
return a.driverType < b.driverType;
7373
}
7474

75+
// Reorders *drivers according to the ZEL_DRIVERS_ORDER environment variable.
//
// The variable is parsed by parseDriverOrder() into a list of specs that are
// applied in the order given:
//   - BY_GLOBAL_INDEX   : the driver at that position in the incoming vector.
//   - BY_TYPE           : every not-yet-placed driver in that type's group.
//   - BY_TYPE_AND_INDEX : the N-th driver within that type's group.
// Grouping used for type lookups:
//   - DISCRETE_GPU and INTEGRATED_GPU each have their own group,
//   - GPU and MIXED drivers share one group,
//   - NPU and OTHER drivers share one group.
// Specs that are out of range, or that name a driver which has already been
// placed, are ignored. Drivers not selected by any spec are appended afterwards
// in their original relative order, so no driver is ever dropped.
//
// If the variable is unset/empty or yields no valid specs, *drivers is left
// untouched.
//
// @param drivers  Driver vector to reorder in place; must not be null.
void context_t::driverOrdering(driver_vector_t *drivers) {
    std::string orderStr = getenv_string("ZEL_DRIVERS_ORDER");
    if (orderStr.empty()) {
        return; // No ordering specified
    }

    const std::vector<DriverOrderSpec> specs = parseDriverOrder(orderStr);
    if (specs.empty()) {
        if (debugTraceEnabled) {
            std::string message = "driverOrdering: ZEL_DRIVERS_ORDER parsing failed or empty: " + orderStr;
            debug_trace_message(message, "");
        }
        return;
    }

    if (debugTraceEnabled) {
        std::string message = "driverOrdering:ZEL_DRIVERS_ORDER parsing successful: " + orderStr + ", specs count: " + std::to_string(specs.size());
        debug_trace_message(message, "");
    }

    // Snapshot of the incoming order; every index below refers to this copy.
    const driver_vector_t originalDrivers = *drivers;

    // Per-type lists of positions in originalDrivers, in ascending order.
    // Only the indices are needed; the drivers themselves are always read
    // from originalDrivers.
    std::vector<uint32_t> discreteGPUIndices;
    std::vector<uint32_t> integratedGPUIndices;
    std::vector<uint32_t> npuIndices;
    std::vector<uint32_t> gpuIndices;

    for (uint32_t i = 0; i < originalDrivers.size(); ++i) {
        switch (originalDrivers[i].driverType) {
            case ZEL_DRIVER_TYPE_DISCRETE_GPU:
                discreteGPUIndices.push_back(i);
                break;
            case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
                integratedGPUIndices.push_back(i);
                break;
            case ZEL_DRIVER_TYPE_GPU:
            case ZEL_DRIVER_TYPE_MIXED:
                // Mixed drivers are grouped with the generic GPU drivers.
                gpuIndices.push_back(i);
                break;
            case ZEL_DRIVER_TYPE_NPU:
            case ZEL_DRIVER_TYPE_OTHER:
                // "Other" drivers are grouped with the NPU drivers.
                npuIndices.push_back(i);
                break;
            default:
                break;
        }
    }

    // Maps a spec's driver type to its index group; nullptr when the type has
    // no group a spec can address.
    auto indicesForType = [&](zel_driver_type_t type) -> const std::vector<uint32_t>* {
        switch (type) {
            case ZEL_DRIVER_TYPE_DISCRETE_GPU:
                return &discreteGPUIndices;
            case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
                return &integratedGPUIndices;
            case ZEL_DRIVER_TYPE_GPU:
                return &gpuIndices;
            case ZEL_DRIVER_TYPE_NPU:
            case ZEL_DRIVER_TYPE_OTHER:
                return &npuIndices;
            default:
                return nullptr;
        }
    };

    driver_vector_t orderedDrivers;
    orderedDrivers.reserve(originalDrivers.size());
    std::set<uint32_t> usedGlobalIndices;

    // Places the driver at globalIdx unless it has already been placed.
    // Tracking the global index alone is sufficient: a (type, index) pair
    // always resolves to exactly one global index.
    auto appendIfUnused = [&](uint32_t globalIdx) {
        if (usedGlobalIndices.insert(globalIdx).second) {
            orderedDrivers.push_back(originalDrivers[globalIdx]);
        }
    };

    // Apply ordering specifications in the order the user listed them.
    for (const auto& spec : specs) {
        switch (spec.type) {
            case DriverOrderSpecType::BY_GLOBAL_INDEX:
                if (spec.globalIndex < originalDrivers.size()) {
                    appendIfUnused(spec.globalIndex);
                }
                break;

            case DriverOrderSpecType::BY_TYPE:
                if (const std::vector<uint32_t>* typeIndices = indicesForType(spec.driverType)) {
                    for (uint32_t globalIdx : *typeIndices) {
                        appendIfUnused(globalIdx);
                    }
                }
                break;

            case DriverOrderSpecType::BY_TYPE_AND_INDEX: {
                const std::vector<uint32_t>* typeIndices = indicesForType(spec.driverType);
                if (typeIndices && spec.typeIndex < typeIndices->size()) {
                    appendIfUnused((*typeIndices)[spec.typeIndex]);
                }
                break;
            }
        }
    }

    // Add remaining drivers in their original order.
    for (uint32_t i = 0; i < originalDrivers.size(); ++i) {
        appendIfUnused(i);
    }

    // Replace the caller's vector with the ordered one without another copy.
    drivers->swap(orderedDrivers);

    if (debugTraceEnabled) {
        std::string message = "driverOrdering: Drivers after ZEL_DRIVERS_ORDER:";
        for (uint32_t i = 0; i < drivers->size(); ++i) {
            message += "\n[" + std::to_string(i) + "] Driver Type: " + std::to_string((*drivers)[i].driverType) + " Driver Name: " + (*drivers)[i].name;
        }
        debug_trace_message(message, "");
    }
}
247+
75248
bool context_t::driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly) {
76249
ze_init_driver_type_desc_t permissiveDesc = {};
77250
permissiveDesc.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
@@ -246,6 +419,10 @@ namespace loader
246419
}
247420
debug_trace_message(message, "");
248421
}
422+
423+
// Apply driver ordering based on ZEL_DRIVERS_ORDER environment variable
424+
driverOrdering(drivers);
425+
249426
return true;
250427
}
251428

@@ -577,7 +754,7 @@ namespace loader
577754
GET_FUNCTION_PTR(validationLayer, "zelLoaderGetVersion"));
578755
zel_component_version_t compVersion;
579756
if(getVersion && ZE_RESULT_SUCCESS == getVersion(&compVersion))
580-
{
757+
{
581758
compVersions.push_back(compVersion);
582759
}
583760
} else if (debugTraceEnabled) {
@@ -602,7 +779,7 @@ namespace loader
602779
GET_FUNCTION_PTR(tracingLayer, "zelLoaderGetVersion"));
603780
zel_component_version_t compVersion;
604781
if(getVersion && ZE_RESULT_SUCCESS == getVersion(&compVersion))
605-
{
782+
{
606783
compVersions.push_back(compVersion);
607784
}
608785
} else if (debugTraceEnabled) {

source/loader/ze_loader_internal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ namespace loader
3838
ZEL_DRIVER_TYPE_INTEGRATED_GPU = 2, ///< The driver has Integrated GPUs only
3939
ZEL_DRIVER_TYPE_MIXED = 3, ///< The driver has Heterogenous driver types not limited to GPU or NPU.
4040
ZEL_DRIVER_TYPE_OTHER = 4, ///< The driver has No GPU Devices and has other device types only
41+
ZEL_DRIVER_TYPE_NPU = 5, ///< The driver has NPU devices only
4142
ZEL_DRIVER_TYPE_FORCE_UINT32 = 0x7fffffff
4243

4344
} zel_driver_type_t;
@@ -150,6 +151,7 @@ namespace loader
150151
ze_result_t init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly);
151152
void add_loader_version();
152153
bool driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly);
154+
void driverOrdering(driver_vector_t *drivers);
153155
~context_t();
154156
bool intercept_enabled = false;
155157
bool debugTraceEnabled = false;

source/loader/ze_loader_utils.h

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
*
3+
* Copyright (C) 2025 Intel Corporation
4+
*
5+
* SPDX-License-Identifier: MIT
6+
*
7+
*/
8+
#pragma once
9+
#include "ze_loader_internal.h"
10+
#include <string>
11+
#include <vector>
12+
#include <map>
13+
#include <set>
14+
#include <sstream>
15+
#include <cstdlib>
16+
#include <algorithm>
17+
18+
19+
namespace loader
20+
{
21+
// Helper function to map driver type string to enum
22+
inline zel_driver_type_t stringToDriverType(const std::string& typeStr) {
23+
if (typeStr == "DISCRETE_GPU_ONLY") {
24+
return ZEL_DRIVER_TYPE_DISCRETE_GPU;
25+
} else if (typeStr == "GPU") {
26+
return ZEL_DRIVER_TYPE_GPU;
27+
} else if (typeStr == "INTEGRATED_GPU_ONLY") {
28+
return ZEL_DRIVER_TYPE_INTEGRATED_GPU;
29+
} else if (typeStr == "NPU") {
30+
return ZEL_DRIVER_TYPE_NPU;
31+
}
32+
return ZEL_DRIVER_TYPE_FORCE_UINT32; // Invalid
33+
}
34+
35+
// Returns a copy of str with leading and trailing whitespace removed
// (space, tab, newline, carriage return, form feed, vertical tab).
// A string that is empty or consists only of whitespace yields "".
inline std::string trim(const std::string& str) {
    const char* const blanks = " \t\n\r\f\v";

    const size_t first = str.find_first_not_of(blanks);
    if (first == std::string::npos) {
        return std::string();
    }

    // A non-blank character exists, so the reverse search cannot fail.
    const size_t last = str.find_last_not_of(blanks);
    const size_t length = last - first + 1;
    return std::string(str, first, length);
}
43+
44+
// Kind of selection a single ZEL_DRIVERS_ORDER entry expresses:
//   BY_GLOBAL_INDEX   - a bare number, i.e. a position in the full driver list
//   BY_TYPE           - a bare driver-type name, i.e. all drivers of that type
//   BY_TYPE_AND_INDEX - "<driver_type>:<driver_index>", i.e. one driver within a type
enum DriverOrderSpecType { BY_GLOBAL_INDEX, BY_TYPE, BY_TYPE_AND_INDEX };
45+
46+
// Structure to hold parsed ordering instructions
47+
struct DriverOrderSpec {
48+
DriverOrderSpecType type;
49+
uint32_t globalIndex = 0;
50+
zel_driver_type_t driverType = ZEL_DRIVER_TYPE_FORCE_UINT32;
51+
uint32_t typeIndex = 0;
52+
};
53+
54+
// Parse ZEL_DRIVERS_ORDER environment variable
55+
inline std::vector<DriverOrderSpec> parseDriverOrder(const std::string& orderStr) {
56+
std::vector<DriverOrderSpec> specs;
57+
58+
// Split by comma
59+
std::vector<std::string> tokens;
60+
std::stringstream ss(orderStr);
61+
std::string token;
62+
63+
while (std::getline(ss, token, ',')) {
64+
token = trim(token);
65+
if (token.empty()) continue;
66+
67+
DriverOrderSpec spec;
68+
69+
// Check if it contains a colon (type:index format)
70+
size_t colonPos = token.find(':');
71+
if (colonPos != std::string::npos) {
72+
// Format: <driver_type>:<driver_index>
73+
std::string typeStr = trim(token.substr(0, colonPos));
74+
std::string indexStr = trim(token.substr(colonPos + 1));
75+
76+
spec.driverType = stringToDriverType(typeStr);
77+
if (spec.driverType == ZEL_DRIVER_TYPE_FORCE_UINT32) {
78+
continue; // Invalid driver type, skip
79+
}
80+
81+
try {
82+
spec.typeIndex = std::stoul(indexStr);
83+
spec.type = DriverOrderSpecType::BY_TYPE_AND_INDEX;
84+
specs.push_back(spec);
85+
} catch (const std::exception&) {
86+
// Invalid index, skip
87+
continue;
88+
}
89+
} else {
90+
// Check if it's a pure number (global index) or driver type
91+
try {
92+
spec.globalIndex = std::stoul(token);
93+
spec.type = DriverOrderSpecType::BY_GLOBAL_INDEX;
94+
specs.push_back(spec);
95+
} catch (const std::exception&) {
96+
// Not a number, try as driver type
97+
spec.driverType = stringToDriverType(token);
98+
if (spec.driverType != ZEL_DRIVER_TYPE_FORCE_UINT32) {
99+
spec.type = DriverOrderSpecType::BY_TYPE;
100+
specs.push_back(spec);
101+
}
102+
}
103+
}
104+
}
105+
106+
return specs;
107+
}
108+
} // namespace loader

0 commit comments

Comments
 (0)