Skip to content

Commit 05d5f1a

Browse files
committed
Call zeInitDrivers in L0 provider
According to the L0 spec, zeInitDrivers must be called (by every library) before calling any other APIs. Not calling zeInitDrivers causes crash when using statically linked L0 loader in UR.
1 parent 8a619a4 commit 05d5f1a

File tree

1 file changed

+44
-5
lines changed

1 file changed

+44
-5
lines changed

src/provider/provider_level_zero.c

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,19 @@
1414
#include <umf/memory_provider_ops.h>
1515
#include <umf/providers/provider_level_zero.h>
1616

17+
#include "base_alloc_global.h"
1718
#include "provider_level_zero_internal.h"
1819
#include "utils_load_library.h"
1920
#include "utils_log.h"
2021

22+
typedef struct _ze_driver_handle_t *ze_driver_handle_t;
23+
24+
static ze_driver_handle_t *zeAllDrivers = NULL;
2125
static void *ze_lib_handle = NULL;
2226

2327
void fini_ze_global_state(void) {
2428
if (ze_lib_handle) {
29+
umf_ba_global_free(zeAllDrivers);
2530
utils_close_library(ze_lib_handle);
2631
ze_lib_handle = NULL;
2732
}
@@ -111,7 +116,6 @@ umf_memory_provider_ops_t *umfLevelZeroMemoryProviderOps(void) {
111116

112117
#else // !defined(UMF_NO_LEVEL_ZERO_PROVIDER)
113118

114-
#include "base_alloc_global.h"
115119
#include "libumf.h"
116120
#include "utils_assert.h"
117121
#include "utils_common.h"
@@ -158,6 +162,8 @@ typedef struct ze_memory_provider_t {
158162
} ze_memory_provider_t;
159163

160164
typedef struct ze_ops_t {
165+
ze_result_t (*zeInitDrivers)(uint32_t *, ze_driver_handle_t *,
166+
ze_init_driver_type_desc_t *);
161167
ze_result_t (*zeMemAllocHost)(ze_context_handle_t,
162168
const ze_host_mem_alloc_desc_t *, size_t,
163169
size_t, void *);
@@ -211,6 +217,28 @@ static umf_result_t ze2umf_result(ze_result_t result) {
211217
}
212218
}
213219

220+
static umf_result_t ze_init_drivers() {
221+
ze_init_driver_type_desc_t desc = {
222+
.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC,
223+
.pNext = NULL,
224+
.flags = UINT32_MAX};
225+
uint32_t driverCount = 0;
226+
ze_result_t result = g_ze_ops.zeInitDrivers(&driverCount, NULL, &desc);
227+
if (result != ZE_RESULT_SUCCESS) {
228+
return ze2umf_result(result);
229+
}
230+
231+
assert(zeAllDrivers == NULL);
232+
zeAllDrivers =
233+
umf_ba_global_alloc(sizeof(ze_driver_handle_t) * driverCount);
234+
result = g_ze_ops.zeInitDrivers(&driverCount, zeAllDrivers, &desc);
235+
if (result != ZE_RESULT_SUCCESS) {
236+
return ze2umf_result(result);
237+
}
238+
239+
return UMF_RESULT_SUCCESS;
240+
}
241+
214242
static void init_ze_global_state(void) {
215243
#ifdef _WIN32
216244
const char *lib_name = "ze_loader.dll";
@@ -228,6 +256,8 @@ static void init_ze_global_state(void) {
228256
return;
229257
}
230258

259+
*(void **)&g_ze_ops.zeInitDrivers =
260+
utils_get_symbol_addr(lib_handle, "zeInitDrivers", lib_name);
231261
*(void **)&g_ze_ops.zeMemAllocHost =
232262
utils_get_symbol_addr(lib_handle, "zeMemAllocHost", lib_name);
233263
*(void **)&g_ze_ops.zeMemAllocDevice =
@@ -253,10 +283,10 @@ static void init_ze_global_state(void) {
253283
*(void **)&g_ze_ops.zeMemGetAllocProperties =
254284
utils_get_symbol_addr(lib_handle, "zeMemGetAllocProperties", lib_name);
255285

256-
if (!g_ze_ops.zeMemAllocHost || !g_ze_ops.zeMemAllocDevice ||
257-
!g_ze_ops.zeMemAllocShared || !g_ze_ops.zeMemFree ||
258-
!g_ze_ops.zeMemGetIpcHandle || !g_ze_ops.zeMemOpenIpcHandle ||
259-
!g_ze_ops.zeMemCloseIpcHandle ||
286+
if (!g_ze_ops.zeInitDrivers || !g_ze_ops.zeMemAllocHost ||
287+
!g_ze_ops.zeMemAllocDevice || !g_ze_ops.zeMemAllocShared ||
288+
!g_ze_ops.zeMemFree || !g_ze_ops.zeMemGetIpcHandle ||
289+
!g_ze_ops.zeMemOpenIpcHandle || !g_ze_ops.zeMemCloseIpcHandle ||
260290
!g_ze_ops.zeContextMakeMemoryResident ||
261291
!g_ze_ops.zeDeviceGetProperties || !g_ze_ops.zeMemGetAllocProperties) {
262292
// g_ze_ops.zeMemPutIpcHandle can be NULL because it was introduced
@@ -267,6 +297,15 @@ static void init_ze_global_state(void) {
267297
return;
268298
}
269299
ze_lib_handle = lib_handle;
300+
301+
umf_result_t result = ze_init_drivers();
302+
if (result != UMF_RESULT_SUCCESS) {
303+
LOG_FATAL("Failed to initialize Level Zero drivers");
304+
Init_ze_global_state_failed = true;
305+
utils_close_library(lib_handle);
306+
lib_handle = NULL;
307+
return;
308+
}
270309
}
271310

272311
umf_result_t umfLevelZeroMemoryProviderParamsCreate(

0 commit comments

Comments
 (0)