@@ -62,44 +62,32 @@ def generate_header(output_file, kernel_name_dict):
62
62
fout .write (rendered )
63
63
64
64
65
- def get_mangled_names (dpcxx_path , source_file , output_header ):
65
+ def get_mangled_names (source_file , output_header ):
66
66
"""Return a list of all the entry point names from a given sycl source file.
67
67
68
68
Filters out wrapper and offset handler entry points.
69
69
"""
70
70
output_dir = os .path .dirname (output_header )
71
- il_file = os .path .join (output_dir , os .path .basename (source_file ) + ".ll" )
72
- generate_il_command = f"""\
73
- { dpcxx_path } -S -fsycl -fsycl-device-code-split=off \
74
- -fsycl-device-only -o { il_file } { source_file } """
75
- subprocess .run (generate_il_command , shell = True )
76
- kernel_line_regex = re .compile ("define.*spir_kernel" )
77
- definition_lines = []
78
- with open (il_file ) as f :
71
+ name = os .path .splitext (os .path .basename (source_file ))[0 ]
72
+ ih_file = os .path .join (output_dir , name , name + ".ih" )
73
+ definitions = []
74
+ writing = False
75
+ with open (ih_file ) as f :
79
76
lines = f .readlines ()
80
77
for line in lines :
81
- if kernel_line_regex .search (line ) is not None :
82
- definition_lines .append (line )
78
+ if "}" in line and writing :
79
+ break
80
+ # __pf_kernel_wrapper seems to be an internal function used by dpcpp
81
+ if writing and "19__pf_kernel_wrapper" not in line :
82
+ definitions .append (line .replace ("," , "" ).strip ()[1 :- 1 ])
83
+ if "const char* const kernel_names[] = {" in line :
84
+ writing = True
83
85
84
- entry_point_names = []
85
- kernel_name_regex = re .compile (r"@(.*?)\(" )
86
- for line in definition_lines :
87
- if kernel_name_regex .search (line ) is None :
88
- continue
89
- match = kernel_name_regex .search (line )
90
- assert isinstance (match , re .Match )
91
- kernel_name = match .group (1 )
92
- if "kernel_wrapper" not in kernel_name and "with_offset" not in kernel_name :
93
- entry_point_names .append (kernel_name )
94
-
95
- os .remove (il_file )
96
- return entry_point_names
86
+ return definitions
97
87
98
88
99
89
def main ():
100
90
parser = argparse .ArgumentParser ()
101
- parser .add_argument ("--dpcxx_path" ,
102
- help = "Full path to dpc++ compiler executable." )
103
91
parser .add_argument (
104
92
"-o" ,
105
93
"--output" ,
@@ -112,7 +100,7 @@ def main():
112
100
for source_file in args .source_files :
113
101
program_name = os .path .splitext (os .path .basename (source_file ))[0 ]
114
102
mangled_names [program_name ] = get_mangled_names (
115
- args . dpcxx_path , source_file , args .output )
103
+ source_file , args .output )
116
104
generate_header (args .output , mangled_names )
117
105
118
106
0 commit comments