Skip to content

Commit 9526595

Browse files
fix: segfaults in zer DDI intercepts (#378)
* fix: segfaults in zer DDI intercepts Signed-off-by: Vishnu Khanth <vishnu.khanth.b@intel.com>
1 parent b8ab77b commit 9526595

File tree

5 files changed

+36
-7
lines changed

5 files changed

+36
-7
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,13 @@ namespace loader
9898
%if not re.match(r"\w+InitDrivers$", th.make_func_name(n, tags, obj)):
9999
std::call_once(loader::context->coreDriverSortOnce, []() {
100100
loader::context->driverSorting(&loader::context->zeDrivers, nullptr, false);
101-
loader::context->defaultZerDriverHandle = &loader::context->zeDrivers.front().zerDriverHandle;
101+
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
102102
loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
103103
});
104104
%else:
105105
std::call_once(loader::context->coreDriverSortOnce, [desc]() {
106106
loader::context->driverSorting(&loader::context->zeDrivers, desc, false);
107-
loader::context->defaultZerDriverHandle = &loader::context->zeDrivers.front().zerDriverHandle;
107+
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
108108
loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
109109
});
110110
%endif
@@ -234,6 +234,10 @@ namespace loader
234234
}
235235
${obj['params'][1]['name']}[ driver_index ] = reinterpret_cast<${n}_driver_handle_t>(
236236
context->${n}_driver_factory.getInstance( ${obj['params'][1]['name']}[ driver_index ], &drv.dditable ) );
237+
if (drv.zerDriverHandle != nullptr) {
238+
drv.zerDriverHandle = reinterpret_cast<${n}_driver_handle_t>(
239+
context->${n}_driver_factory.getInstance( drv.zerDriverHandle, &drv.dditable ) );
240+
}
237241
} else if (drv.properties.flags & ZE_DRIVER_DDI_HANDLE_EXT_FLAG_DDI_HANDLE_EXT_SUPPORTED) {
238242
if (loader::context->debugTraceEnabled) {
239243
std::string message = "Driver DDI Handles Supported for " + drv.name;

scripts/templates/ze_loader_internal.h.mako

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ namespace loader
131131
bool instrumentationEnabled = false;
132132
dditable_t tracing_dditable = {};
133133
std::shared_ptr<Logger> zel_logger;
134-
ze_driver_handle_t* defaultZerDriverHandle = nullptr;
134+
ze_driver_handle_t defaultZerDriverHandle = nullptr;
135135
};
136136

137137
extern ze_handle_t* loaderDispatch;

source/loader/ze_ldrddi.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ namespace loader
6363
if (!loader::context->sortingInProgress.exchange(true) && !loader::context->instrumentationEnabled) {
6464
std::call_once(loader::context->coreDriverSortOnce, []() {
6565
loader::context->driverSorting(&loader::context->zeDrivers, nullptr, false);
66-
loader::context->defaultZerDriverHandle = &loader::context->zeDrivers.front().zerDriverHandle;
66+
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
6767
loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
6868
});
6969
loader::context->sortingInProgress.store(false);
@@ -156,6 +156,10 @@ namespace loader
156156
}
157157
phDrivers[ driver_index ] = reinterpret_cast<ze_driver_handle_t>(
158158
context->ze_driver_factory.getInstance( phDrivers[ driver_index ], &drv.dditable ) );
159+
if (drv.zerDriverHandle != nullptr) {
160+
drv.zerDriverHandle = reinterpret_cast<ze_driver_handle_t>(
161+
context->ze_driver_factory.getInstance( drv.zerDriverHandle, &drv.dditable ) );
162+
}
159163
} else if (drv.properties.flags & ZE_DRIVER_DDI_HANDLE_EXT_FLAG_DDI_HANDLE_EXT_SUPPORTED) {
160164
if (loader::context->debugTraceEnabled) {
161165
std::string message = "Driver DDI Handles Supported for " + drv.name;
@@ -208,7 +212,7 @@ namespace loader
208212
if (!loader::context->sortingInProgress.exchange(true) && !loader::context->instrumentationEnabled) {
209213
std::call_once(loader::context->coreDriverSortOnce, [desc]() {
210214
loader::context->driverSorting(&loader::context->zeDrivers, desc, false);
211-
loader::context->defaultZerDriverHandle = &loader::context->zeDrivers.front().zerDriverHandle;
215+
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
212216
loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
213217
});
214218
loader::context->sortingInProgress.store(false);
@@ -303,6 +307,10 @@ namespace loader
303307
}
304308
phDrivers[ driver_index ] = reinterpret_cast<ze_driver_handle_t>(
305309
context->ze_driver_factory.getInstance( phDrivers[ driver_index ], &drv.dditable ) );
310+
if (drv.zerDriverHandle != nullptr) {
311+
drv.zerDriverHandle = reinterpret_cast<ze_driver_handle_t>(
312+
context->ze_driver_factory.getInstance( drv.zerDriverHandle, &drv.dditable ) );
313+
}
306314
} else if (drv.properties.flags & ZE_DRIVER_DDI_HANDLE_EXT_FLAG_DDI_HANDLE_EXT_SUPPORTED) {
307315
if (loader::context->debugTraceEnabled) {
308316
std::string message = "Driver DDI Handles Supported for " + drv.name;

source/loader/ze_loader.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,23 @@ namespace loader
374374
continue;
375375
}
376376
driver.driverDDIHandleSupportQueried = true;
377+
378+
if (!(driver.properties.flags & ZE_DRIVER_DDI_HANDLE_EXT_FLAG_DDI_HANDLE_EXT_SUPPORTED) || !loader::context->driverDDIPathDefault) {
379+
if (debugTraceEnabled) {
380+
std::string message = "driverSorting: Driver DDI Handles Not Supported for " + driver.name;
381+
debug_trace_message(message, "");
382+
}
383+
if (driver.zerDriverHandle != nullptr) {
384+
driver.zerDriverHandle = reinterpret_cast<ze_driver_handle_t>(
385+
loader::context->ze_driver_factory.getInstance(driver.zerDriverHandle, &driver.dditable));
386+
}
387+
} else {
388+
if (debugTraceEnabled) {
389+
std::string message = "driverSorting: Driver DDI Handles Supported for " + driver.name;
390+
debug_trace_message(message, "");
391+
}
392+
}
393+
377394
uint32_t deviceCount = 0;
378395
res = driver.dditable.ze.Device.pfnGet( handle, &deviceCount, nullptr );
379396
if( ZE_RESULT_SUCCESS != res ) {
@@ -551,7 +568,7 @@ namespace loader
551568
return ZE_RESULT_ERROR_UNINITIALIZED;
552569

553570
// Set default driver handle and DDI table to the first driver in the list before sorting.
554-
loader::context->defaultZerDriverHandle = &loader::context->zeDrivers.front().zerDriverHandle;
571+
loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
555572
loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
556573
return ZE_RESULT_SUCCESS;
557574
}

source/loader/ze_loader_internal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ namespace loader
167167
bool instrumentationEnabled = false;
168168
dditable_t tracing_dditable = {};
169169
std::shared_ptr<Logger> zel_logger;
170-
ze_driver_handle_t* defaultZerDriverHandle = nullptr;
170+
ze_driver_handle_t defaultZerDriverHandle = nullptr;
171171
};
172172

173173
extern ze_handle_t* loaderDispatch;

0 commit comments

Comments
 (0)