Skip to content

Commit 6067d4a

Browse files
authored
Generalize module format in AOT tool (#4818)
This PR modifies the aot-tool to potentially support several ze-module formats (aside from spirv and native binary). The `compiler.cpp` now has a string `format_name` template parameter instead of a boolean `is_spv` that contains a format name ("native", "spirv", ...). The string is then matched to an actual format in `get_module_format()`. Signed-off-by: dchigarev <[email protected]>
1 parent c97bd76 commit 6067d4a

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

python/triton/tools/compile.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,16 @@ def constexpr(s):
205205
"_placeholder": "",
206206
}
207207
if is_xpu():
208+
if args.generate_native_code:
209+
format_name = "native"
210+
else:
211+
format_name = "spirv"
208212
params |= {
209213
"arg_types": ", ".join(ty_to_cpp(arg) for arg in arg_types_not_1),
210214
"grf_mode": args.grf_mode,
211215
"build_flags": ccinfo.metadata.build_flags,
212216
"threads_per_warp": args.threads_per_warp,
213-
"is_spv": "false" if args.generate_native_code else "true",
217+
"format_name": format_name,
214218
}
215219
output_files = []
216220
backend_name = target.backend

third_party/intel/tools/intel/compile.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ static inline void gpuAssert(ze_result_t code, const char *file, int line) {{
2525
}}
2626

2727
// ze globals
28-
#define SPV_NAME {kernel_name}_spv
28+
#define BIN_NAME {kernel_name}_bin
2929
ze_module_handle_t {kernel_name}_mod = NULL;
3030
ze_kernel_handle_t {kernel_name}_func = NULL;
31-
unsigned char SPV_NAME[{bin_size}] = {{ {bin_data} }};
31+
unsigned char BIN_NAME[{bin_size}] = {{ {bin_data} }};
3232
// sycl globals
3333
const sycl::device sycl_device;
3434
const auto ctx =
@@ -42,15 +42,25 @@ void unload_{kernel_name}(void) {{
4242
// Not implemeted
4343
}}
4444

45+
static ze_module_format_t get_module_format(const std::string& format_name) {{
46+
if (format_name == "spirv") {{
47+
return ZE_MODULE_FORMAT_IL_SPIRV;
48+
}} else if (format_name == "native") {{
49+
return ZE_MODULE_FORMAT_NATIVE;
50+
}} else {{
51+
throw std::runtime_error("Unsupported module format");
52+
}}
53+
}}
54+
4555
void load_{kernel_name}() {{
46-
uint8_t *binary_ptr = (uint8_t *)&SPV_NAME;
56+
uint8_t *binary_ptr = (uint8_t *)&BIN_NAME;
4757
size_t binary_size = {bin_size};
4858

49-
const bool is_spv = {is_spv};
59+
const std::string format_name = "{format_name}";
5060

5161
ze_module_desc_t module_description {{}};
5262
module_description.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
53-
module_description.format = is_spv ? ZE_MODULE_FORMAT_IL_SPIRV : ZE_MODULE_FORMAT_NATIVE;
63+
module_description.format = get_module_format(format_name);
5464
module_description.inputSize = static_cast<uint32_t>(binary_size);
5565
module_description.pInputModule = binary_ptr;
5666
module_description.pBuildFlags = "{build_flags}";

0 commit comments

Comments
 (0)