|
12 | 12 | import subprocess |
13 | 13 | import sys |
14 | 14 |
|
15 | | -from mako.template import Template |
16 | | - |
17 | | -HEADER_TEMPLATE = Template("""/* |
| 15 | +HEADER_TEMPLATE = """/* |
18 | 16 | * |
19 | 17 | * Copyright (C) 2023 Intel Corporation |
20 | 18 | * |
21 | 19 | * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. |
22 | 20 | * See LICENSE.TXT |
23 | 21 | * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
24 | 22 | * |
25 | | - * @file ${file_name}.h |
| 23 | + * @file %s.h |
26 | 24 | * |
27 | 25 | */ |
28 | 26 |
|
|
33 | 31 | namespace uur { |
34 | 32 | namespace device_binaries { |
35 | 33 | std::map<std::string, std::vector<std::string>> program_kernel_map = { |
36 | | -% for program, entry_points in kernel_name_dict.items(): |
37 | | - {"${program}", { |
38 | | - % for entry_point in entry_points: |
39 | | - "${entry_point}", |
40 | | - % endfor |
41 | | - }}, |
42 | | -% endfor |
| 34 | +%s |
43 | 35 | }; |
44 | 36 | } |
45 | 37 | } |
46 | | -""") |
| 38 | +""" |
| 39 | + |
| 40 | +PROGRAM_TEMPLATE = """\ |
| 41 | + {"%s", { |
| 42 | +%s |
| 43 | + }}, |
| 44 | +""" |
47 | 45 |
|
| 46 | +ENTRY_POINT_TEMPLATE = """\ |
| 47 | + "%s", |
| 48 | +""" |
48 | 49 |
|
49 | 50 | def generate_header(output_file, kernel_name_dict): |
50 | 51 | """Render the template and write it to the output file.""" |
51 | 52 | file_name = os.path.basename(output_file) |
52 | | - rendered = HEADER_TEMPLATE.render(file_name=file_name, |
53 | | - kernel_name_dict=kernel_name_dict) |
| 53 | + device_binaries = "" |
| 54 | + for program, entry_points in kernel_name_dict.items(): |
| 55 | + content = "" |
| 56 | + for entry_point in entry_points: |
| 57 | + content += ENTRY_POINT_TEMPLATE % entry_point |
| 58 | + device_binaries += PROGRAM_TEMPLATE % (program, content) |
| 59 | + rendered = HEADER_TEMPLATE % (file_name, device_binaries) |
54 | 60 | rendered = re.sub(r"\r\n", r"\n", rendered) |
55 | | - |
56 | 61 | with open(output_file, "w") as fout: |
57 | 62 | fout.write(rendered) |
58 | 63 |
|
@@ -81,7 +86,9 @@ def get_mangled_names(dpcxx_path, source_file, output_header): |
81 | 86 | for line in definition_lines: |
82 | 87 | if kernel_name_regex.search(line) is None: |
83 | 88 | continue |
84 | | - kernel_name = kernel_name_regex.search(line).group(1) |
| 89 | + match = kernel_name_regex.search(line) |
| 90 | + assert isinstance(match, re.Match) |
| 91 | + kernel_name = match.group(1) |
85 | 92 | if "kernel_wrapper" not in kernel_name and "with_offset" not in kernel_name: |
86 | 93 | entry_point_names.append(kernel_name) |
87 | 94 |
|
|
0 commit comments