55 * SPDX-License-Identifier: MIT
66 *
77 */
8- #include " ze_loader_internal .h"
8+ #include " ze_loader_utils .h"
99
1010#include " driver_discovery.h"
1111#include < iostream>
@@ -72,6 +72,179 @@ namespace loader
7272 return a.driverType < b.driverType ;
7373 }
7474
75+ void context_t::driverOrdering (driver_vector_t *drivers) {
76+ std::string orderStr = getenv_string (" ZEL_DRIVERS_ORDER" );
77+ if (orderStr.empty ()) {
78+ return ; // No ordering specified
79+ }
80+
81+ std::vector<DriverOrderSpec> specs = parseDriverOrder (orderStr);
82+
83+ if (specs.empty ()) {
84+ if (debugTraceEnabled) {
85+ std::string message = " driverOrdering: ZEL_DRIVERS_ORDER parsing failed or empty: " + orderStr;
86+ debug_trace_message (message, " " );
87+ }
88+ return ;
89+ }
90+
91+ if (debugTraceEnabled) {
92+ std::string message = " driverOrdering:ZEL_DRIVERS_ORDER parsing successful: " + orderStr + " , specs count: " + std::to_string (specs.size ());
93+ debug_trace_message (message, " " );
94+ }
95+
96+ // Create a copy of the original driver vector for reference
97+ driver_vector_t originalDrivers = *drivers;
98+
99+ driver_vector_t discreteGPUDrivers;
100+ driver_vector_t integratedGPUDrivers;
101+ driver_vector_t npuDrivers;
102+ driver_vector_t gpuDrivers;
103+
104+ std::vector<uint32_t > discreteGPUIndices;
105+ std::vector<uint32_t > integratedGPUIndices;
106+ std::vector<uint32_t > npuIndices;
107+ std::vector<uint32_t > gpuIndices;
108+
109+ // Group drivers by type and track their original indices
110+ for (uint32_t i = 0 ; i < originalDrivers.size (); ++i) {
111+ const auto & driver = originalDrivers[i];
112+ switch (driver.driverType ) {
113+ case ZEL_DRIVER_TYPE_DISCRETE_GPU:
114+ discreteGPUDrivers.push_back (driver);
115+ discreteGPUIndices.push_back (i);
116+ break ;
117+ case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
118+ integratedGPUDrivers.push_back (driver);
119+ integratedGPUIndices.push_back (i);
120+ break ;
121+ case ZEL_DRIVER_TYPE_GPU:
122+ gpuDrivers.push_back (driver);
123+ gpuIndices.push_back (i);
124+ break ;
125+ case ZEL_DRIVER_TYPE_NPU:
126+ npuDrivers.push_back (driver);
127+ npuIndices.push_back (i);
128+ break ;
129+ case ZEL_DRIVER_TYPE_OTHER:
130+ npuDrivers.push_back (driver);
131+ npuIndices.push_back (i);
132+ break ;
133+ case ZEL_DRIVER_TYPE_MIXED:
134+ // Mixed drivers go to gpuDrivers
135+ gpuDrivers.push_back (driver);
136+ gpuIndices.push_back (i);
137+ break ;
138+ default :
139+ break ;
140+ }
141+ }
142+
143+ // Create new ordered driver vector
144+ driver_vector_t orderedDrivers;
145+ std::set<uint32_t > usedGlobalIndices;
146+ std::set<std::pair<zel_driver_type_t , uint32_t >> usedTypeIndices;
147+
148+ // Apply ordering specifications
149+ for (const auto & spec : specs) {
150+ switch (spec.type ) {
151+ case DriverOrderSpecType::BY_GLOBAL_INDEX:
152+ if (spec.globalIndex < originalDrivers.size () &&
153+ usedGlobalIndices.find (spec.globalIndex ) == usedGlobalIndices.end ()) {
154+ orderedDrivers.push_back (originalDrivers[spec.globalIndex ]);
155+ usedGlobalIndices.insert (spec.globalIndex );
156+ }
157+ break ;
158+
159+ case DriverOrderSpecType::BY_TYPE:
160+ // Add all drivers of this type that haven't been used
161+ {
162+ std::vector<uint32_t >* typeIndices = nullptr ;
163+ switch (spec.driverType ) {
164+ case ZEL_DRIVER_TYPE_DISCRETE_GPU:
165+ typeIndices = &discreteGPUIndices;
166+ break ;
167+ case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
168+ typeIndices = &integratedGPUIndices;
169+ break ;
170+ case ZEL_DRIVER_TYPE_GPU:
171+ typeIndices = &gpuIndices;
172+ break ;
173+ case ZEL_DRIVER_TYPE_NPU:
174+ case ZEL_DRIVER_TYPE_OTHER:
175+ typeIndices = &npuIndices;
176+ break ;
177+ default :
178+ break ;
179+ }
180+
181+ if (typeIndices) {
182+ for (uint32_t globalIdx : *typeIndices) {
183+ if (usedGlobalIndices.find (globalIdx) == usedGlobalIndices.end ()) {
184+ orderedDrivers.push_back (originalDrivers[globalIdx]);
185+ usedGlobalIndices.insert (globalIdx);
186+ }
187+ }
188+ }
189+ }
190+ break ;
191+
192+ case DriverOrderSpecType::BY_TYPE_AND_INDEX:
193+ {
194+ std::vector<uint32_t >* typeIndices = nullptr ;
195+ switch (spec.driverType ) {
196+ case ZEL_DRIVER_TYPE_DISCRETE_GPU:
197+ typeIndices = &discreteGPUIndices;
198+ break ;
199+ case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
200+ typeIndices = &integratedGPUIndices;
201+ break ;
202+ case ZEL_DRIVER_TYPE_GPU:
203+ typeIndices = &gpuIndices;
204+ break ;
205+ case ZEL_DRIVER_TYPE_NPU:
206+ case ZEL_DRIVER_TYPE_OTHER:
207+ typeIndices = &npuIndices;
208+ break ;
209+ default :
210+ break ;
211+ }
212+
213+ if (typeIndices && spec.typeIndex < typeIndices->size ()) {
214+ auto typeIndexPair = std::make_pair (spec.driverType , spec.typeIndex );
215+ if (usedTypeIndices.find (typeIndexPair) == usedTypeIndices.end ()) {
216+ uint32_t globalIdx = (*typeIndices)[spec.typeIndex ];
217+ if (usedGlobalIndices.find (globalIdx) == usedGlobalIndices.end ()) {
218+ orderedDrivers.push_back (originalDrivers[globalIdx]);
219+ usedGlobalIndices.insert (globalIdx);
220+ usedTypeIndices.insert (typeIndexPair);
221+ }
222+ }
223+ }
224+ }
225+ break ;
226+ }
227+ }
228+
229+ // Add remaining drivers in their original order
230+ for (uint32_t i = 0 ; i < originalDrivers.size (); ++i) {
231+ if (usedGlobalIndices.find (i) == usedGlobalIndices.end ()) {
232+ orderedDrivers.push_back (originalDrivers[i]);
233+ }
234+ }
235+
236+ // Replace the original driver vector with the ordered one
237+ *drivers = orderedDrivers;
238+
239+ if (debugTraceEnabled) {
240+ std::string message = " driverOrdering: Drivers after ZEL_DRIVERS_ORDER:" ;
241+ for (uint32_t i = 0 ; i < drivers->size (); ++i) {
242+ message += " \n [" + std::to_string (i) + " ] Driver Type: " + std::to_string ((*drivers)[i].driverType ) + " Driver Name: " + (*drivers)[i].name ;
243+ }
244+ debug_trace_message (message, " " );
245+ }
246+ }
247+
75248 bool context_t::driverSorting (driver_vector_t *drivers, ze_init_driver_type_desc_t * desc, bool sysmanOnly) {
76249 ze_init_driver_type_desc_t permissiveDesc = {};
77250 permissiveDesc.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
@@ -246,6 +419,10 @@ namespace loader
246419 }
247420 debug_trace_message (message, " " );
248421 }
422+
423+ // Apply driver ordering based on ZEL_DRIVERS_ORDER environment variable
424+ driverOrdering (drivers);
425+
249426 return true ;
250427 }
251428
@@ -577,7 +754,7 @@ namespace loader
577754 GET_FUNCTION_PTR (validationLayer, " zelLoaderGetVersion" ));
578755 zel_component_version_t compVersion;
579756 if (getVersion && ZE_RESULT_SUCCESS == getVersion (&compVersion))
580- {
757+ {
581758 compVersions.push_back (compVersion);
582759 }
583760 } else if (debugTraceEnabled) {
@@ -602,7 +779,7 @@ namespace loader
602779 GET_FUNCTION_PTR (tracingLayer, " zelLoaderGetVersion" ));
603780 zel_component_version_t compVersion;
604781 if (getVersion && ZE_RESULT_SUCCESS == getVersion (&compVersion))
605- {
782+ {
606783 compVersions.push_back (compVersion);
607784 }
608785 } else if (debugTraceEnabled) {
0 commit comments