@@ -62,44 +62,32 @@ def generate_header(output_file, kernel_name_dict):
6262 fout .write (rendered )
6363
6464
65- def get_mangled_names (dpcxx_path , source_file , output_header ):
65+ def get_mangled_names (source_file , output_header ):
6666 """Return a list of all the entry point names from a given sycl source file.
6767
6868 Filters out wrapper and offset handler entry points.
6969 """
7070 output_dir = os .path .dirname (output_header )
71- il_file = os .path .join (output_dir , os .path .basename (source_file ) + ".ll" )
72- generate_il_command = f"""\
73- { dpcxx_path } -S -fsycl -fsycl-device-code-split=off \
74- -fsycl-device-only -o { il_file } { source_file } """
75- subprocess .run (generate_il_command , shell = True )
76- kernel_line_regex = re .compile ("define.*spir_kernel" )
77- definition_lines = []
78- with open (il_file ) as f :
71+ name = os .path .splitext (os .path .basename (source_file ))[0 ]
72+ ih_file = os .path .join (output_dir , name , name + ".ih" )
73+ definitions = []
74+ writing = False
75+ with open (ih_file ) as f :
7976 lines = f .readlines ()
8077 for line in lines :
81- if kernel_line_regex .search (line ) is not None :
82- definition_lines .append (line )
78+ if "}" in line and writing :
79+ break
80+ # __pf_kernel_wrapper seems to be an internal function used by dpcpp
81+ if writing and "19__pf_kernel_wrapper" not in line :
82+ definitions .append (line .replace ("," , "" ).strip ()[1 :- 1 ])
83+ if "const char* const kernel_names[] = {" in line :
84+ writing = True
8385
84- entry_point_names = []
85- kernel_name_regex = re .compile (r"@(.*?)\(" )
86- for line in definition_lines :
87- if kernel_name_regex .search (line ) is None :
88- continue
89- match = kernel_name_regex .search (line )
90- assert isinstance (match , re .Match )
91- kernel_name = match .group (1 )
92- if "kernel_wrapper" not in kernel_name and "with_offset" not in kernel_name :
93- entry_point_names .append (kernel_name )
94-
95- os .remove (il_file )
96- return entry_point_names
86+ return definitions
9787
9888
9989def main ():
10090 parser = argparse .ArgumentParser ()
101- parser .add_argument ("--dpcxx_path" ,
102- help = "Full path to dpc++ compiler executable." )
10391 parser .add_argument (
10492 "-o" ,
10593 "--output" ,
@@ -112,7 +100,7 @@ def main():
112100 for source_file in args .source_files :
113101 program_name = os .path .splitext (os .path .basename (source_file ))[0 ]
114102 mangled_names [program_name ] = get_mangled_names (
115- args . dpcxx_path , source_file , args .output )
103+ source_file , args .output )
116104 generate_header (args .output , mangled_names )
117105
118106
0 commit comments