1414#include <umf/memory_provider_ops.h>
1515#include <umf/providers/provider_level_zero.h>
1616
17+ #include "base_alloc_global.h"
1718#include "provider_level_zero_internal.h"
1819#include "utils_load_library.h"
1920#include "utils_log.h"
2021
22+ typedef struct _ze_driver_handle_t * ze_driver_handle_t ;
23+
24+ static ze_driver_handle_t * zeAllDrivers = NULL ;
2125static void * ze_lib_handle = NULL ;
2226
2327void fini_ze_global_state (void ) {
2428 if (ze_lib_handle ) {
29+ umf_ba_global_free (zeAllDrivers );
2530 utils_close_library (ze_lib_handle );
2631 ze_lib_handle = NULL ;
2732 }
@@ -111,7 +116,6 @@ umf_memory_provider_ops_t *umfLevelZeroMemoryProviderOps(void) {
111116
112117#else // !defined(UMF_NO_LEVEL_ZERO_PROVIDER)
113118
114- #include "base_alloc_global.h"
115119#include "libumf.h"
116120#include "utils_assert.h"
117121#include "utils_common.h"
@@ -158,6 +162,8 @@ typedef struct ze_memory_provider_t {
158162} ze_memory_provider_t ;
159163
160164typedef struct ze_ops_t {
165+ ze_result_t (* zeInitDrivers )(uint32_t * , ze_driver_handle_t * ,
166+ ze_init_driver_type_desc_t * );
161167 ze_result_t (* zeMemAllocHost )(ze_context_handle_t ,
162168 const ze_host_mem_alloc_desc_t * , size_t ,
163169 size_t , void * );
@@ -211,6 +217,28 @@ static umf_result_t ze2umf_result(ze_result_t result) {
211217 }
212218}
213219
220+ static umf_result_t ze_init_drivers () {
221+ ze_init_driver_type_desc_t desc = {
222+ .stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC ,
223+ .pNext = NULL ,
224+ .flags = UINT32_MAX };
225+ uint32_t driverCount = 0 ;
226+ ze_result_t result = g_ze_ops .zeInitDrivers (& driverCount , NULL , & desc );
227+ if (result != ZE_RESULT_SUCCESS ) {
228+ return ze2umf_result (result );
229+ }
230+
231+ assert (zeAllDrivers == NULL );
232+ zeAllDrivers =
233+ umf_ba_global_alloc (sizeof (ze_driver_handle_t ) * driverCount );
234+ result = g_ze_ops .zeInitDrivers (& driverCount , zeAllDrivers , & desc );
235+ if (result != ZE_RESULT_SUCCESS ) {
236+ return ze2umf_result (result );
237+ }
238+
239+ return UMF_RESULT_SUCCESS ;
240+ }
241+
214242static void init_ze_global_state (void ) {
215243#ifdef _WIN32
216244 const char * lib_name = "ze_loader.dll" ;
@@ -228,6 +256,8 @@ static void init_ze_global_state(void) {
228256 return ;
229257 }
230258
259+ * (void * * )& g_ze_ops .zeInitDrivers =
260+ utils_get_symbol_addr (lib_handle , "zeInitDrivers" , lib_name );
231261 * (void * * )& g_ze_ops .zeMemAllocHost =
232262 utils_get_symbol_addr (lib_handle , "zeMemAllocHost" , lib_name );
233263 * (void * * )& g_ze_ops .zeMemAllocDevice =
@@ -253,10 +283,10 @@ static void init_ze_global_state(void) {
253283 * (void * * )& g_ze_ops .zeMemGetAllocProperties =
254284 utils_get_symbol_addr (lib_handle , "zeMemGetAllocProperties" , lib_name );
255285
256- if (!g_ze_ops .zeMemAllocHost || !g_ze_ops .zeMemAllocDevice ||
257- !g_ze_ops .zeMemAllocShared || !g_ze_ops .zeMemFree ||
258- !g_ze_ops .zeMemGetIpcHandle || !g_ze_ops .zeMemOpenIpcHandle ||
259- !g_ze_ops .zeMemCloseIpcHandle ||
286+ if (!g_ze_ops .zeInitDrivers || !g_ze_ops .zeMemAllocHost ||
287+ !g_ze_ops .zeMemAllocDevice || !g_ze_ops .zeMemAllocShared ||
288+ !g_ze_ops .zeMemFree || !g_ze_ops .zeMemGetIpcHandle ||
289+ !g_ze_ops .zeMemOpenIpcHandle || ! g_ze_ops . zeMemCloseIpcHandle ||
260290 !g_ze_ops .zeContextMakeMemoryResident ||
261291 !g_ze_ops .zeDeviceGetProperties || !g_ze_ops .zeMemGetAllocProperties ) {
262292 // g_ze_ops.zeMemPutIpcHandle can be NULL because it was introduced
@@ -267,6 +297,15 @@ static void init_ze_global_state(void) {
267297 return ;
268298 }
269299 ze_lib_handle = lib_handle ;
300+
301+ umf_result_t result = ze_init_drivers ();
302+ if (result != UMF_RESULT_SUCCESS ) {
303+ LOG_FATAL ("Failed to initialize Level Zero drivers" );
304+ Init_ze_global_state_failed = true;
305+ utils_close_library (lib_handle );
306+ lib_handle = NULL ;
307+ return ;
308+ }
270309}
271310
272311umf_result_t umfLevelZeroMemoryProviderParamsCreate (
0 commit comments