@@ -63,7 +63,6 @@ namespace loader
6363 if (!loader::context->sortingInProgress.exchange(true) && !loader::context->instrumentationEnabled) {
6464 std::call_once(loader::context->coreDriverSortOnce, []() {
6565 loader::context->driverSorting(&loader::context->zeDrivers, nullptr, false);
66- loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
6766 loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
6867 });
6968 loader::context->sortingInProgress.store(false);
@@ -128,6 +127,7 @@ namespace loader
128127 if (strcmp(extensionProperties[extIndex].name, ZE_DRIVER_DDI_HANDLES_EXT_NAME) == 0 && (!(extensionProperties[extIndex].version >= ZE_DRIVER_DDI_HANDLES_EXT_VERSION_1_1))) {
129128 // Driver supports DDI Handles but not the required version for ZER APIs so set the driverHandle to nullptr
130129 drv.zerDriverHandle = nullptr;
130+ drv.zerDriverDDISupported = false;
131131 break;
132132 }
133133 }
@@ -157,8 +157,7 @@ namespace loader
157157 phDrivers[ driver_index ] = reinterpret_cast<ze_driver_handle_t>(
158158 context->ze_driver_factory.getInstance( phDrivers[ driver_index ], &drv.dditable ) );
159159 if (drv.zerDriverHandle != nullptr) {
160- drv.zerDriverHandle = reinterpret_cast<ze_driver_handle_t>(
161- context->ze_driver_factory.getInstance( drv.zerDriverHandle, &drv.dditable ) );
160+ drv.zerDriverHandle = phDrivers[ driver_index ];
162161 }
163162 } else if (drv.properties.flags & ZE_DRIVER_DDI_HANDLE_EXT_FLAG_DDI_HANDLE_EXT_SUPPORTED) {
164163 if (loader::context->debugTraceEnabled) {
@@ -183,6 +182,10 @@ namespace loader
183182 if (total_driver_handle_count > 0) {
184183 result = ZE_RESULT_SUCCESS;
185184 }
185+ if (loader::context->zeDrivers.front().zerDriverDDISupported)
186+ loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
187+ else
188+ loader::context->defaultZerDriverHandle = nullptr;
186189
187190 return result;
188191 }
@@ -212,7 +215,6 @@ namespace loader
212215 if (!loader::context->sortingInProgress.exchange(true) && !loader::context->instrumentationEnabled) {
213216 std::call_once(loader::context->coreDriverSortOnce, [desc]() {
214217 loader::context->driverSorting(&loader::context->zeDrivers, desc, false);
215- loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
216218 loader::defaultZerDdiTable = &loader::context->zeDrivers.front().dditable.zer;
217219 });
218220 loader::context->sortingInProgress.store(false);
@@ -279,6 +281,7 @@ namespace loader
279281 if (strcmp(extensionProperties[extIndex].name, ZE_DRIVER_DDI_HANDLES_EXT_NAME) == 0 && (!(extensionProperties[extIndex].version >= ZE_DRIVER_DDI_HANDLES_EXT_VERSION_1_1))) {
280282 // Driver supports DDI Handles but not the required version for ZER APIs so set the driverHandle to nullptr
281283 drv.zerDriverHandle = nullptr;
284+ drv.zerDriverDDISupported = false;
282285 break;
283286 }
284287 }
@@ -308,8 +311,7 @@ namespace loader
308311 phDrivers[ driver_index ] = reinterpret_cast<ze_driver_handle_t>(
309312 context->ze_driver_factory.getInstance( phDrivers[ driver_index ], &drv.dditable ) );
310313 if (drv.zerDriverHandle != nullptr) {
311- drv.zerDriverHandle = reinterpret_cast<ze_driver_handle_t>(
312- context->ze_driver_factory.getInstance( drv.zerDriverHandle, &drv.dditable ) );
314+ drv.zerDriverHandle = phDrivers[ driver_index ];
313315 }
314316 } else if (drv.properties.flags & ZE_DRIVER_DDI_HANDLE_EXT_FLAG_DDI_HANDLE_EXT_SUPPORTED) {
315317 if (loader::context->debugTraceEnabled) {
@@ -334,6 +336,10 @@ namespace loader
334336 if (total_driver_handle_count > 0) {
335337 result = ZE_RESULT_SUCCESS;
336338 }
339+ if (loader::context->zeDrivers.front().zerDriverDDISupported)
340+ loader::context->defaultZerDriverHandle = loader::context->zeDrivers.front().zerDriverHandle;
341+ else
342+ loader::context->defaultZerDriverHandle = nullptr;
337343
338344 return result;
339345 }
0 commit comments