Skip to content

Commit 928a993

Browse files
authored
Fix Dynamic generation of code for _factory globals (#294)
* Fix Dynamic generation of code for _factory globals * Rename internal loader header template and update comment Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
1 parent d1baac0 commit 928a993

File tree

4 files changed

+280
-39
lines changed

4 files changed

+280
-39
lines changed

scripts/generate_code.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,23 @@ def _mako_loader_cpp(path, namespace, tags, version, specs, meta):
107107
filename = "%s.h"%(name)
108108
fout = os.path.join(path, filename)
109109

110+
print("Generating %s..."%fout)
111+
loc += util.makoWrite(
112+
fin, fout,
113+
name=name,
114+
ver=version,
115+
namespace=namespace,
116+
tags=tags,
117+
specs=specs,
118+
meta=meta)
119+
120+
template = "ze_loader_internal.h.mako"
121+
fin = os.path.join(templates_dir, template)
122+
123+
name = "%s_loader_internal_tmp"%(namespace)
124+
filename = "%s.h"%(name)
125+
fout = os.path.join(path, filename)
126+
110127
print("Generating %s..."%fout)
111128
loc += util.makoWrite(
112129
fin, fout,

scripts/json2src.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,106 @@ def add_argument(parser, name, help, default=False):
5252
if args.drivers:
5353
generate_code.generate_drivers(srcpath, config['name'], config['namespace'], config['tags'], args.ver, specs, input['meta'])
5454

55+
def merge_header_files(input_files, output_file):
56+
"""
57+
Merges the unique content of multiple header files into a single output file.
58+
59+
Args:
60+
input_files: A list of paths to the input header files.
61+
output_file: The path to the output file.
62+
"""
63+
try:
64+
unique_lines = set()
65+
for infile_path in input_files:
66+
try:
67+
with open(infile_path, 'r') as infile:
68+
inside_factory = False
69+
for line in infile:
70+
if '/// factories' in line:
71+
inside_factory = True
72+
continue
73+
if '/// end factories' in line:
74+
inside_factory = False
75+
continue
76+
if inside_factory:
77+
unique_lines.add(line)
78+
except FileNotFoundError:
79+
print(f"Error: Input file not found: {infile_path}")
80+
return
81+
82+
with open(output_file, 'w') as outfile:
83+
for line in sorted(unique_lines):
84+
outfile.write(line)
85+
print(f"Successfully merged unique header file content into: {output_file}")
86+
except Exception as e:
87+
print(f"An error occurred: {e}")
88+
89+
header_files = [
90+
'source/loader/ze_loader_internal_tmp.h',
91+
'source/loader/zet_loader_internal_tmp.h',
92+
'source/loader/zes_loader_internal_tmp.h'
93+
]
94+
output_file = 'source/loader/ze_loader_internal_factories.h'
95+
merge_header_files(header_files, output_file)
96+
def replace_factory_section(input_file, factory_file, output_file):
97+
"""
98+
Replaces the content between '/// factory' and '/// end factory' in the input file
99+
with the content from the factory file and writes the result to the output file.
100+
101+
Args:
102+
input_file: The path to the input file.
103+
factory_file: The path to the factory file.
104+
output_file: The path to the output file.
105+
"""
106+
try:
107+
with open(input_file, 'r') as infile:
108+
lines = infile.readlines()
109+
110+
with open(factory_file, 'r') as factory:
111+
factory_lines = factory.readlines()
112+
113+
output_lines = []
114+
inside_factory = False
115+
116+
for line in lines:
117+
if '/// factories' in line:
118+
inside_factory = True
119+
output_lines.append(line)
120+
output_lines.extend(factory_lines)
121+
elif '/// end factories' in line:
122+
inside_factory = False
123+
if not inside_factory:
124+
output_lines.append(line)
125+
126+
with open(output_file, 'w') as outfile:
127+
outfile.writelines(output_lines)
128+
129+
print(f"Successfully replaced factory section in: {output_file}")
130+
except Exception as e:
131+
print(f"An error occurred: {e}")
132+
133+
input_file = 'source/loader/ze_loader_internal_tmp.h'
134+
factory_file = 'source/loader/ze_loader_internal_factories.h'
135+
output_file = 'source/loader/ze_loader_internal.h'
136+
replace_factory_section(input_file, factory_file, output_file)
137+
138+
# Delete temporary and factory files
139+
files_to_delete = [
140+
'source/loader/ze_loader_internal_tmp.h',
141+
'source/loader/zet_loader_internal_tmp.h',
142+
'source/loader/zes_loader_internal_tmp.h',
143+
'source/loader/ze_loader_internal_factories.h'
144+
]
145+
146+
for file_path in files_to_delete:
147+
try:
148+
os.remove(file_path)
149+
print(f"Deleted file: {file_path}")
150+
except FileNotFoundError:
151+
print(f"File not found, could not delete: {file_path}")
152+
except Exception as e:
153+
print(f"An error occurred while deleting {file_path}: {e}")
154+
55155
if args.debug:
56156
util.makoFileListWrite("generated.json")
57157

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
<%!
2+
import re
3+
from templates import helper as th
4+
%><%
5+
n=namespace
6+
N=n.upper()
7+
8+
x=tags['$x']
9+
X=x.upper()
10+
%>/*
11+
*
12+
* Copyright (C) 2019-2025 Intel Corporation
13+
*
14+
* SPDX-License-Identifier: MIT
15+
*
16+
* @file ze_loader_internal.h generated from ze_loader_internal.h.mako
17+
*
18+
*/
19+
#pragma once
20+
#include <vector>
21+
#include <map>
22+
#include <atomic>
23+
24+
#include "ze_ddi.h"
25+
#include "zet_ddi.h"
26+
#include "zes_ddi.h"
27+
28+
#include "ze_util.h"
29+
#include "ze_object.h"
30+
31+
#include "ze_ldrddi.h"
32+
#include "zet_ldrddi.h"
33+
#include "zes_ldrddi.h"
34+
35+
#include "loader/ze_loader.h"
36+
#include "../utils/logging.h"
37+
#include "spdlog/spdlog.h"
38+
namespace loader
39+
{
40+
///////////////////////////////////////////////////////////////////////////////
41+
/// @brief Driver Type Enumerations
42+
/// @details The ordering of the drivers reported to the user is based on the order of the enumerations provided.
43+
/// When additional driver types are added, they should be added to the end of the list to avoid reporting new device types
44+
/// before known device types.
45+
typedef enum _zel_driver_type_t
46+
{
47+
ZEL_DRIVER_TYPE_DISCRETE_GPU= 0, ///< The driver has Discrete GPUs only
48+
ZEL_DRIVER_TYPE_GPU = 1, ///< The driver has Heterogenous GPU types
49+
ZEL_DRIVER_TYPE_INTEGRATED_GPU = 2, ///< The driver has Integrated GPUs only
50+
ZEL_DRIVER_TYPE_MIXED = 3, ///< The driver has Heterogenous driver types not limited to GPU or NPU.
51+
ZEL_DRIVER_TYPE_OTHER = 4, ///< The driver has No GPU Devices and has other device types only
52+
ZEL_DRIVER_TYPE_FORCE_UINT32 = 0x7fffffff
53+
54+
} zel_driver_type_t;
55+
//////////////////////////////////////////////////////////////////////////
56+
struct driver_t
57+
{
58+
HMODULE handle = NULL;
59+
ze_result_t initStatus = ZE_RESULT_SUCCESS;
60+
ze_result_t initDriversStatus = ZE_RESULT_SUCCESS;
61+
dditable_t dditable = {};
62+
std::string name;
63+
bool driverInuse = false;
64+
zel_driver_type_t driverType;
65+
ze_driver_properties_t properties;
66+
bool pciOrderingRequested = false;
67+
};
68+
69+
using driver_vector_t = std::vector< driver_t >;
70+
71+
///////////////////////////////////////////////////////////////////////////////
72+
class context_t
73+
{
74+
public:
75+
/// factories
76+
///////////////////////////////////////////////////////////////////////////////
77+
%for obj in th.extract_objs(specs, r"handle"):
78+
%if 'class' in obj:
79+
<%
80+
81+
_handle_t = th.subt(n, tags, obj['name'])
82+
_factory_t = re.sub(r"(\w+)_handle_t", r"\1_factory_t", _handle_t)
83+
_factory = re.sub(r"(\w+)_handle_t", r"\1_factory", _handle_t)
84+
%>${th.append_ws(_factory_t, 35)} ${_factory};
85+
%endif
86+
%endfor
87+
///////////////////////////////////////////////////////////////////////////////
88+
/// end factories
89+
std::mutex image_handle_map_lock;
90+
std::mutex sampler_handle_map_lock;
91+
std::unordered_map<ze_image_object_t *, ze_image_handle_t> image_handle_map;
92+
std::unordered_map<ze_sampler_object_t *, ze_sampler_handle_t> sampler_handle_map;
93+
ze_api_version_t version = ZE_API_VERSION_CURRENT;
94+
95+
driver_vector_t allDrivers;
96+
driver_vector_t zeDrivers;
97+
driver_vector_t zesDrivers;
98+
driver_vector_t *sysmanInstanceDrivers;
99+
100+
HMODULE validationLayer = nullptr;
101+
HMODULE tracingLayer = nullptr;
102+
bool driverEnvironmentQueried = false;
103+
104+
bool forceIntercept = false;
105+
bool initDriversSupport = false;
106+
std::vector<zel_component_version_t> compVersions;
107+
const char *LOADER_COMP_NAME = "loader";
108+
109+
ze_result_t check_drivers(ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly);
110+
void debug_trace_message(std::string errorMessage, std::string errorValue);
111+
ze_result_t init();
112+
ze_result_t init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly);
113+
void add_loader_version();
114+
~context_t();
115+
bool intercept_enabled = false;
116+
bool debugTraceEnabled = false;
117+
bool tracingLayerEnabled = false;
118+
dditable_t tracing_dditable = {};
119+
std::shared_ptr<Logger> zel_logger;
120+
};
121+
122+
extern context_t *context;
123+
}

source/loader/ze_loader_internal.h

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
/*
22
*
3-
* Copyright (C) 2019-2022 Intel Corporation
3+
* Copyright (C) 2019-2025 Intel Corporation
44
*
55
* SPDX-License-Identifier: MIT
66
*
7+
* @file ze_loader_internal.h generated from ze_loader_internal.h.mako
8+
*
79
*/
810
#pragma once
911
#include <vector>
@@ -61,64 +63,63 @@ namespace loader
6163
class context_t
6264
{
6365
public:
66+
/// factories
6467
///////////////////////////////////////////////////////////////////////////////
65-
ze_driver_factory_t ze_driver_factory;
66-
ze_device_factory_t ze_device_factory;
67-
ze_context_factory_t ze_context_factory;
68-
ze_command_queue_factory_t ze_command_queue_factory;
6968
ze_command_list_factory_t ze_command_list_factory;
70-
ze_fence_factory_t ze_fence_factory;
71-
ze_event_pool_factory_t ze_event_pool_factory;
69+
ze_command_queue_factory_t ze_command_queue_factory;
70+
ze_context_factory_t ze_context_factory;
71+
ze_device_factory_t ze_device_factory;
72+
ze_driver_factory_t ze_driver_factory;
7273
ze_event_factory_t ze_event_factory;
74+
ze_event_pool_factory_t ze_event_pool_factory;
75+
ze_external_semaphore_ext_factory_t ze_external_semaphore_ext_factory;
76+
ze_fabric_edge_factory_t ze_fabric_edge_factory;
77+
ze_fabric_vertex_factory_t ze_fabric_vertex_factory;
78+
ze_fence_factory_t ze_fence_factory;
7379
ze_image_factory_t ze_image_factory;
74-
ze_module_factory_t ze_module_factory;
75-
ze_module_build_log_factory_t ze_module_build_log_factory;
7680
ze_kernel_factory_t ze_kernel_factory;
77-
ze_sampler_factory_t ze_sampler_factory;
81+
ze_module_build_log_factory_t ze_module_build_log_factory;
82+
ze_module_factory_t ze_module_factory;
7883
ze_physical_mem_factory_t ze_physical_mem_factory;
79-
ze_fabric_vertex_factory_t ze_fabric_vertex_factory;
80-
ze_fabric_edge_factory_t ze_fabric_edge_factory;
81-
ze_external_semaphore_ext_factory_t ze_external_semaphore_ext_factory;
8284
ze_rtas_builder_exp_factory_t ze_rtas_builder_exp_factory;
8385
ze_rtas_parallel_operation_exp_factory_t ze_rtas_parallel_operation_exp_factory;
84-
///////////////////////////////////////////////////////////////////////////////
85-
zes_driver_factory_t zes_driver_factory;
86+
ze_sampler_factory_t ze_sampler_factory;
8687
zes_device_factory_t zes_device_factory;
87-
zes_sched_factory_t zes_sched_factory;
88-
zes_perf_factory_t zes_perf_factory;
89-
zes_pwr_factory_t zes_pwr_factory;
90-
zes_freq_factory_t zes_freq_factory;
88+
zes_diag_factory_t zes_diag_factory;
89+
zes_driver_factory_t zes_driver_factory;
9190
zes_engine_factory_t zes_engine_factory;
92-
zes_standby_factory_t zes_standby_factory;
93-
zes_firmware_factory_t zes_firmware_factory;
94-
zes_mem_factory_t zes_mem_factory;
9591
zes_fabric_port_factory_t zes_fabric_port_factory;
96-
zes_temp_factory_t zes_temp_factory;
97-
zes_psu_factory_t zes_psu_factory;
9892
zes_fan_factory_t zes_fan_factory;
93+
zes_firmware_factory_t zes_firmware_factory;
94+
zes_freq_factory_t zes_freq_factory;
9995
zes_led_factory_t zes_led_factory;
100-
zes_ras_factory_t zes_ras_factory;
101-
zes_diag_factory_t zes_diag_factory;
96+
zes_mem_factory_t zes_mem_factory;
10297
zes_overclock_factory_t zes_overclock_factory;
98+
zes_perf_factory_t zes_perf_factory;
99+
zes_psu_factory_t zes_psu_factory;
100+
zes_pwr_factory_t zes_pwr_factory;
101+
zes_ras_factory_t zes_ras_factory;
102+
zes_sched_factory_t zes_sched_factory;
103+
zes_standby_factory_t zes_standby_factory;
104+
zes_temp_factory_t zes_temp_factory;
103105
zes_vf_factory_t zes_vf_factory;
104-
///////////////////////////////////////////////////////////////////////////////
105-
zet_driver_factory_t zet_driver_factory;
106-
zet_device_factory_t zet_device_factory;
107-
zet_context_factory_t zet_context_factory;
108106
zet_command_list_factory_t zet_command_list_factory;
109-
zet_module_factory_t zet_module_factory;
107+
zet_context_factory_t zet_context_factory;
108+
zet_debug_session_factory_t zet_debug_session_factory;
109+
zet_device_factory_t zet_device_factory;
110+
zet_driver_factory_t zet_driver_factory;
110111
zet_kernel_factory_t zet_kernel_factory;
111-
zet_metric_group_factory_t zet_metric_group_factory;
112+
zet_metric_decoder_exp_factory_t zet_metric_decoder_exp_factory;
112113
zet_metric_factory_t zet_metric_factory;
113-
zet_metric_streamer_factory_t zet_metric_streamer_factory;
114-
zet_metric_query_pool_factory_t zet_metric_query_pool_factory;
115-
zet_metric_query_factory_t zet_metric_query_factory;
116-
zet_tracer_exp_factory_t zet_tracer_exp_factory;
117-
zet_debug_session_factory_t zet_debug_session_factory;
114+
zet_metric_group_factory_t zet_metric_group_factory;
118115
zet_metric_programmable_exp_factory_t zet_metric_programmable_exp_factory;
116+
zet_metric_query_factory_t zet_metric_query_factory;
117+
zet_metric_query_pool_factory_t zet_metric_query_pool_factory;
118+
zet_metric_streamer_factory_t zet_metric_streamer_factory;
119119
zet_metric_tracer_exp_factory_t zet_metric_tracer_exp_factory;
120-
zet_metric_decoder_exp_factory_t zet_metric_decoder_exp_factory;
121-
///////////////////////////////////////////////////////////////////////////////
120+
zet_module_factory_t zet_module_factory;
121+
zet_tracer_exp_factory_t zet_tracer_exp_factory;
122+
/// end factories
122123
std::mutex image_handle_map_lock;
123124
std::mutex sampler_handle_map_lock;
124125
std::unordered_map<ze_image_object_t *, ze_image_handle_t> image_handle_map;

0 commit comments

Comments
 (0)