88#
99# ==-------------------------------------------------------------------------==#
1010
11-
1211import yaml
1312import argparse
14-
1513from pathlib import Path
1614from header import HeaderFile
15+ from gpu_headers import GpuHeaderFile as GpuHeader
1716from class_implementation .classes .macro import Macro
1817from class_implementation .classes .type import Type
1918from class_implementation .classes .function import Function
2221from class_implementation .classes .object import Object
2322
2423
25- def yaml_to_classes (yaml_data ):
24+ def yaml_to_classes (yaml_data , header_class , entry_points = None ):
2625 """
2726 Convert YAML data to header classes.
2827
2928 Args:
3029 yaml_data: The YAML data containing header specifications.
30+ header_class: The class to use for creating the header.
31+ entry_points: A list of specific function names to include in the header.
3132
3233 Returns:
3334 HeaderFile: An instance of HeaderFile populated with the data.
3435 """
3536 header_name = yaml_data .get ("header" )
36- header = HeaderFile (header_name )
37+ header = header_class (header_name )
3738
3839 for macro_data in yaml_data .get ("macros" , []):
3940 header .add_macro (Macro (macro_data ["macro_name" ], macro_data ["macro_value" ]))
@@ -49,12 +50,15 @@ def yaml_to_classes(yaml_data):
4950 )
5051
5152 functions = yaml_data .get ("functions" , [])
53+ if entry_points :
54+ entry_points_set = set (entry_points )
55+ functions = [f for f in functions if f ["name" ] in entry_points_set ]
5256 sorted_functions = sorted (functions , key = lambda x : x ["name" ])
5357 guards = []
5458 guarded_function_dict = {}
5559 for function_data in sorted_functions :
5660 guard = function_data .get ("guard" , None )
57- if guard == None :
61+ if guard is None :
5862 arguments = [arg ["type" ] for arg in function_data ["arguments" ]]
5963 attributes = function_data .get ("attributes" , None )
6064 standards = function_data .get ("standards" , None )
@@ -105,19 +109,21 @@ def yaml_to_classes(yaml_data):
105109 return header
106110
107111
108- def load_yaml_file (yaml_file ):
112+ def load_yaml_file (yaml_file , header_class , entry_points ):
109113 """
110114 Load YAML file and convert it to header classes.
111115
112116 Args:
113- yaml_file: The path to the YAML file.
117+ yaml_file: Path to the YAML file.
118+ header_class: The class to use for creating the header (HeaderFile or GpuHeader).
119+ entry_points: A list of specific function names to include in the header.
114120
115121 Returns:
116- HeaderFile: An instance of HeaderFile populated with the data from the YAML file .
122+ HeaderFile: An instance of HeaderFile populated with the data.
117123 """
118124 with open (yaml_file , "r" ) as f :
119125 yaml_data = yaml .safe_load (f )
120- return yaml_to_classes (yaml_data )
126+ return yaml_to_classes (yaml_data , header_class , entry_points )
121127
122128
123129def fill_public_api (header_str , h_def_content ):
@@ -207,7 +213,14 @@ def increase_indent(self, flow=False, indentless=False):
207213 print (f"Added function { new_function .name } to { yaml_file } " )
208214
209215
210- def main (yaml_file , h_def_file , output_dir , add_function = None ):
216+ def main (
217+ yaml_file ,
218+ output_dir = None ,
219+ h_def_file = None ,
220+ add_function = None ,
221+ entry_points = None ,
222+ export_decls = False ,
223+ ):
211224 """
212225 Main function to generate header files from YAML and .h.def templates.
213226
@@ -216,41 +229,50 @@ def main(yaml_file, h_def_file, output_dir, add_function=None):
216229 h_def_file: Path to the .h.def template file.
217230 output_dir: Directory to output the generated header file.
218231 add_function: Details of the function to be added to the YAML file (if any).
232+ entry_points: A list of specific function names to include in the header.
233+ export_decls: Flag to use GpuHeader for exporting declarations.
219234 """
220-
221235 if add_function :
222236 add_function_to_yaml (yaml_file , add_function )
223237
224- header = load_yaml_file (yaml_file )
225-
226- with open (h_def_file , "r" ) as f :
227- h_def_content = f .read ()
238+ header_class = GpuHeader if export_decls else HeaderFile
239+ header = load_yaml_file (yaml_file , header_class , entry_points )
228240
229241 header_str = str (header )
230- final_header_content = fill_public_api (header_str , h_def_content )
231242
232- output_file_name = Path (h_def_file ).stem
233- output_file_path = Path (output_dir ) / output_file_name
234-
235- with open (output_file_path , "w" ) as f :
236- f .write (final_header_content )
243+ if output_dir :
244+ output_file_path = Path (output_dir )
245+ if output_file_path .is_dir ():
246+ output_file_path /= f"{ Path (yaml_file ).stem } .h"
247+ else :
248+ output_file_path = Path (f"{ Path (yaml_file ).stem } .h" )
249+
250+ if not export_decls and h_def_file :
251+ with open (h_def_file , "r" ) as f :
252+ h_def_content = f .read ()
253+ final_header_content = fill_public_api (header_str , h_def_content )
254+ with open (output_file_path , "w" ) as f :
255+ f .write (final_header_content )
256+ else :
257+ with open (output_file_path , "w" ) as f :
258+ f .write (header_str )
237259
238260 print (f"Generated header file: { output_file_path } " )
239261
240262
241263if __name__ == "__main__" :
242- parser = argparse .ArgumentParser (
243- description = "Generate header files from YAML and .h.def templates"
244- )
264+ parser = argparse .ArgumentParser (description = "Generate header files from YAML" )
245265 parser .add_argument (
246266 "yaml_file" , help = "Path to the YAML file containing header specification"
247267 )
248- parser .add_argument ("h_def_file" , help = "Path to the .h.def template file" )
249268 parser .add_argument (
250269 "--output_dir" ,
251- default = "." ,
252270 help = "Directory to output the generated header file" ,
253271 )
272+ parser .add_argument (
273+ "--h_def_file" ,
274+ help = "Path to the .h.def template file (required if not using --export_decls)" ,
275+ )
254276 parser .add_argument (
255277 "--add_function" ,
256278 nargs = 6 ,
@@ -264,6 +286,21 @@ def main(yaml_file, h_def_file, output_dir, add_function=None):
264286 ),
265287 help = "Add a function to the YAML file" ,
266288 )
289+ parser .add_argument (
290+ "--e" , action = "append" , help = "Entry point to include" , dest = "entry_points"
291+ )
292+ parser .add_argument (
293+ "--export-decls" ,
294+ action = "store_true" ,
295+ help = "Flag to use GpuHeader for exporting declarations" ,
296+ )
267297 args = parser .parse_args ()
268298
269- main (args .yaml_file , args .h_def_file , args .output_dir , args .add_function )
299+ main (
300+ args .yaml_file ,
301+ args .output_dir ,
302+ args .h_def_file ,
303+ args .add_function ,
304+ args .entry_points ,
305+ args .export_decls ,
306+ )
0 commit comments