Skip to content

Commit d27aa28

Browse files
authored
Fix to zer Handle init to address replacement of ze driver handles (#385)
- Removed generation of zerDriver Handles and reuse the ze driver handles when running in legacy ddi table modes. - Fixed ULT for driverget to properly read handles. Signed-off-by: Neil R. Spruit <[email protected]>
1 parent 26dcdb5 commit d27aa28

File tree

7 files changed

+33
-18
lines changed

7 files changed

+33
-18
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,11 @@ namespace loader
9898
%if not re.match(r"\w+InitDrivers$", th.make_func_name(n, tags, obj)):
9999
std::call_once(loader::context->coreDriverSortOnce, []() {
100100
loader::context->driverSorting(&loader::context->zeDrivers, nullptr, false);
101-
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
102101
loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
103102
});
104103
%else:
105104
std::call_once(loader::context->coreDriverSortOnce, [desc]() {
106105
loader::context->driverSorting(&loader::context->zeDrivers, desc, false);
107-
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
108106
loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
109107
});
110108
%endif
@@ -180,8 +178,8 @@ namespace loader
180178
{
181179
for( uint32_t i = 0; i < library_driver_handle_count; ++i ) {
182180
uint32_t driver_index = total_driver_handle_count + i;
183-
drv.zerDriverHandle = phDrivers[ driver_index ];
184181
%if namespace != "zes":
182+
drv.zerDriverHandle = phDrivers[ driver_index ];
185183
if (drv.driverDDIHandleSupportQueried == false) {
186184
uint32_t extensionCount = 0;
187185
ze_result_t res = drv.dditable.ze.Driver.pfnGetExtensionProperties(phDrivers[ driver_index ], &extensionCount, nullptr);
@@ -206,6 +204,7 @@ namespace loader
206204
if (strcmp(extensionProperties[extIndex].name, ZE_DRIVER_DDI_HANDLES_EXT_NAME) == 0 && (!(extensionProperties[extIndex].version >= ZE_DRIVER_DDI_HANDLES_EXT_VERSION_1_1))) {
207205
// Driver supports DDI Handles but not the required version for ZER APIs so set the driverHandle to nullptr
208206
drv.zerDriverHandle = nullptr;
207+
drv.zerDriverDDISupported = false;
209208
break;
210209
}
211210
}
@@ -235,8 +234,7 @@ namespace loader
235234
${obj['params'][1]['name']}[ driver_index ] = reinterpret_cast<${n}_driver_handle_t>(
236235
context->${n}_driver_factory.getInstance( ${obj['params'][1]['name']}[ driver_index ], &drv.dditable ) );
237236
if (drv.zerDriverHandle != nullptr) {
238-
drv.zerDriverHandle = reinterpret_cast<${n}_driver_handle_t>(
239-
context->${n}_driver_factory.getInstance( drv.zerDriverHandle, &drv.dditable ) );
237+
drv.zerDriverHandle = ${obj['params'][1]['name']}[ driver_index ];
240238
}
241239
} else if (drv.properties.flags & ZE_DRIVER_DDI_HANDLE_EXT_FLAG_DDI_HANDLE_EXT_SUPPORTED) {
242240
if (loader::context->debugTraceEnabled) {
@@ -265,7 +263,13 @@ namespace loader
265263
if (total_driver_handle_count > 0) {
266264
result = ${X}_RESULT_SUCCESS;
267265
}
266+
%if namespace != "zes":
267+
if (loader::context->zeDrivers.front().zerDriverDDISupported)
268+
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
269+
else
270+
loader::context->defaultZerDriverHandle = nullptr;
268271

272+
%endif
269273
%else:
270274
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
271275
%if 0 == i:

scripts/templates/ze_loader_internal.h.mako

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ namespace loader
6969
bool legacyInitAttempted = false;
7070
bool driverDDIHandleSupportQueried = false;
7171
ze_driver_handle_t zerDriverHandle = nullptr;
72+
bool zerDriverDDISupported = true;
7273
};
7374

7475
using driver_vector_t = std::vector< driver_t >;

source/loader/ze_ldrddi.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ namespace loader
6363
if (!loader::context->sortingInProgress.exchange(true) && !loader::context->instrumentationEnabled) {
6464
std::call_once(loader::context->coreDriverSortOnce, []() {
6565
loader::context->driverSorting(&loader::context->zeDrivers, nullptr, false);
66-
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
6766
loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
6867
});
6968
loader::context->sortingInProgress.store(false);
@@ -128,6 +127,7 @@ namespace loader
128127
if (strcmp(extensionProperties[extIndex].name, ZE_DRIVER_DDI_HANDLES_EXT_NAME) == 0 && (!(extensionProperties[extIndex].version >= ZE_DRIVER_DDI_HANDLES_EXT_VERSION_1_1))) {
129128
// Driver supports DDI Handles but not the required version for ZER APIs so set the driverHandle to nullptr
130129
drv.zerDriverHandle = nullptr;
130+
drv.zerDriverDDISupported = false;
131131
break;
132132
}
133133
}
@@ -157,8 +157,7 @@ namespace loader
157157
phDrivers[ driver_index ] = reinterpret_cast<ze_driver_handle_t>(
158158
context->ze_driver_factory.getInstance( phDrivers[ driver_index ], &drv.dditable ) );
159159
if (drv.zerDriverHandle != nullptr) {
160-
drv.zerDriverHandle = reinterpret_cast<ze_driver_handle_t>(
161-
context->ze_driver_factory.getInstance( drv.zerDriverHandle, &drv.dditable ) );
160+
drv.zerDriverHandle = phDrivers[ driver_index ];
162161
}
163162
} else if (drv.properties.flags & ZE_DRIVER_DDI_HANDLE_EXT_FLAG_DDI_HANDLE_EXT_SUPPORTED) {
164163
if (loader::context->debugTraceEnabled) {
@@ -183,6 +182,10 @@ namespace loader
183182
if (total_driver_handle_count > 0) {
184183
result = ZE_RESULT_SUCCESS;
185184
}
185+
if (loader::context->zeDrivers.front().zerDriverDDISupported)
186+
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
187+
else
188+
loader::context->defaultZerDriverHandle = nullptr;
186189

187190
return result;
188191
}
@@ -212,7 +215,6 @@ namespace loader
212215
if (!loader::context->sortingInProgress.exchange(true) && !loader::context->instrumentationEnabled) {
213216
std::call_once(loader::context->coreDriverSortOnce, [desc]() {
214217
loader::context->driverSorting(&loader::context->zeDrivers, desc, false);
215-
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
216218
loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
217219
});
218220
loader::context->sortingInProgress.store(false);
@@ -279,6 +281,7 @@ namespace loader
279281
if (strcmp(extensionProperties[extIndex].name, ZE_DRIVER_DDI_HANDLES_EXT_NAME) == 0 && (!(extensionProperties[extIndex].version >= ZE_DRIVER_DDI_HANDLES_EXT_VERSION_1_1))) {
280282
// Driver supports DDI Handles but not the required version for ZER APIs so set the driverHandle to nullptr
281283
drv.zerDriverHandle = nullptr;
284+
drv.zerDriverDDISupported = false;
282285
break;
283286
}
284287
}
@@ -308,8 +311,7 @@ namespace loader
308311
phDrivers[ driver_index ] = reinterpret_cast<ze_driver_handle_t>(
309312
context->ze_driver_factory.getInstance( phDrivers[ driver_index ], &drv.dditable ) );
310313
if (drv.zerDriverHandle != nullptr) {
311-
drv.zerDriverHandle = reinterpret_cast<ze_driver_handle_t>(
312-
context->ze_driver_factory.getInstance( drv.zerDriverHandle, &drv.dditable ) );
314+
drv.zerDriverHandle = phDrivers[ driver_index ];
313315
}
314316
} else if (drv.properties.flags & ZE_DRIVER_DDI_HANDLE_EXT_FLAG_DDI_HANDLE_EXT_SUPPORTED) {
315317
if (loader::context->debugTraceEnabled) {
@@ -334,6 +336,10 @@ namespace loader
334336
if (total_driver_handle_count > 0) {
335337
result = ZE_RESULT_SUCCESS;
336338
}
339+
if (loader::context->zeDrivers.front().zerDriverDDISupported)
340+
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
341+
else
342+
loader::context->defaultZerDriverHandle = nullptr;
337343

338344
return result;
339345
}

source/loader/ze_loader.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,11 @@ namespace loader
355355
if (strcmp(extensionProperties[extIndex].name, ZE_DRIVER_DDI_HANDLES_EXT_NAME) == 0 && (!(extensionProperties[extIndex].version >= ZE_DRIVER_DDI_HANDLES_EXT_VERSION_1_1))) {
356356
// Driver supports DDI Handles but not the required version for ZER APIs so set the driverHandle to nullptr
357357
driver.zerDriverHandle = nullptr;
358+
driver.zerDriverDDISupported = false;
358359
break;
359360
}
360361
}
362+
361363
}
362364
driver.properties = {};
363365
driver.properties.stype = ZE_STRUCTURE_TYPE_DRIVER_DDI_HANDLES_EXT_PROPERTIES;
@@ -380,10 +382,6 @@ namespace loader
380382
std::string message = "driverSorting: Driver DDI Handles Not Supported for " + driver.name;
381383
debug_trace_message(message, "");
382384
}
383-
if (driver.zerDriverHandle != nullptr) {
384-
driver.zerDriverHandle = reinterpret_cast<ze_driver_handle_t>(
385-
loader::context->ze_driver_factory.getInstance(driver.zerDriverHandle, &driver.dditable));
386-
}
387385
} else {
388386
if (debugTraceEnabled) {
389387
std::string message = "driverSorting: Driver DDI Handles Supported for " + driver.name;
@@ -568,7 +566,10 @@ namespace loader
568566
return ZE_RESULT_ERROR_UNINITIALIZED;
569567

570568
// Set default driver handle and DDI table to the first driver in the list before sorting.
571-
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
569+
if (loader::context->zeDrivers.front().zerDriverDDISupported)
570+
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
571+
else
572+
loader::context->defaultZerDriverHandle = nullptr;
572573
loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
573574
return ZE_RESULT_SUCCESS;
574575
}

source/loader/ze_loader_internal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ namespace loader
6060
bool legacyInitAttempted = false;
6161
bool driverDDIHandleSupportQueried = false;
6262
ze_driver_handle_t zerDriverHandle = nullptr;
63+
bool zerDriverDDISupported = true;
6364
};
6465

6566
using driver_vector_t = std::vector< driver_t >;

source/loader/zes_ldrddi.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ namespace loader
105105
{
106106
for( uint32_t i = 0; i < library_driver_handle_count; ++i ) {
107107
uint32_t driver_index = total_driver_handle_count + i;
108-
drv.zerDriverHandle = phDrivers[ driver_index ];
109108
phDrivers[ driver_index ] = reinterpret_cast<zes_driver_handle_t>(
110109
context->zes_driver_factory.getInstance( phDrivers[ driver_index ], &drv.dditable ) );
111110
}
@@ -125,7 +124,6 @@ namespace loader
125124
if (total_driver_handle_count > 0) {
126125
result = ZE_RESULT_SUCCESS;
127126
}
128-
129127
return result;
130128
}
131129

test/loader_api.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2400,6 +2400,10 @@ TEST_F(DriverOrderingTest,
24002400
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0));
24012401
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGet(&driverGetCount, nullptr));
24022402
EXPECT_GT(driverGetCount, 0);
2403+
std::vector<ze_driver_handle_t> drivers;
2404+
drivers.resize(driverGetCount);
2405+
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGet(&driverGetCount, drivers.data()));
2406+
EXPECT_GT(driverGetCount, 0);
24032407

24042408
const char *errorString = nullptr;
24052409
uint32_t deviceId = 0;

0 commit comments

Comments
 (0)