1- from typing import Any
1+ from typing import Any , Tuple
22from abc import ABC , abstractmethod
33import numpy as np
44
5+ # Function templates
6+ cpp_template : str = """
7+ <!?PREPROCESSOR?!>
8+ <!?USER_DEFINES?!>
9+ #include <chrono>
10+
11+ extern "C" <!?SIGNATURE?!> {
12+ <!?INITIALIZATION?!>
13+ <!?BODY?!>
14+ <!?DEINITIALIZATION?!>
15+ }
16+ """
17+
18+ f90_template : str = """
19+ <!?PREPROCESSOR?!>
20+ <!?USER_DEFINES?!>
21+
22+ module kt
23+ use iso_c_binding
24+ contains
25+
26+ <!?SIGNATURE?!>
27+ <!?INITIALIZATION?!>
28+ <!?BODY?!>
29+ <!?DEINITIALIZATION?!>
30+ end function <!?NAME?!>
31+
32+ end module kt
33+ """
34+
535
636class Directive (ABC ):
737 """Base class for all directives"""
@@ -339,7 +369,7 @@ def wrap_timing_fortran(code: str) -> str:
339369
340370def end_timing_cxx (code : str ) -> str :
341371 """In C++ we need to return the measured time"""
342- return code + " \n return elapsed_time.count();\n "
372+ return " \n " . join ([ code , "return elapsed_time.count();\n "])
343373
344374
345375def wrap_data (code : str , langs : Code , data : dict , preprocessor : list = None , user_dimensions : dict = None ) -> str :
@@ -355,7 +385,7 @@ def wrap_data(code: str, langs: Code, data: dict, preprocessor: list = None, use
355385 elif is_openacc (langs .directive ) and is_fortran (langs .language ):
356386 intro += create_data_directive_openacc_fortran (name , size )
357387 outro += exit_data_directive_openacc_fortran (name , size )
358- return intro + code + outro
388+ return " \n " . join ([ intro , code , outro ])
359389
360390
361391def extract_directive_code (code : str , langs : Code , kernel_name : str = None ) -> dict :
@@ -529,42 +559,34 @@ def generate_directive_function(
529559) -> str :
530560 """Generate tunable function for one directive"""
531561
532- code = "\n " .join (preprocessor ) + "\n "
533- if user_dimensions is not None :
534- # add user dimensions to preprocessor
535- for key , value in user_dimensions .items ():
536- code += f"#define { key } { value } \n "
537- if is_cxx (langs .language ) and "#include <chrono>" not in preprocessor :
538- code += "\n #include <chrono>\n "
539- if is_cxx (langs .language ):
540- code += 'extern "C" ' + signature + "{\n "
541- elif is_fortran (langs .language ):
542- code += "\n module kt\n use iso_c_binding\n contains\n "
543- code += "\n " + signature
544- if len (initialization ) > 1 :
545- code += initialization + "\n "
546- if data is not None :
547- body = add_present_openacc (body , langs , data , preprocessor , user_dimensions )
548562 if is_cxx (langs .language ):
563+ code = cpp_template
549564 body = start_timing_cxx (body )
550565 if data is not None :
551- code += wrap_data (body + "\n " , langs , data , preprocessor , user_dimensions )
552- else :
553- code += body
554- code = end_timing_cxx (code )
555- if len (deinitialization ) > 1 :
556- code += deinitialization + "\n "
557- code += "\n }"
566+ body = wrap_data (body + "\n " , langs , data , preprocessor , user_dimensions )
567+ body = end_timing_cxx (body )
558568 elif is_fortran (langs .language ):
569+ code = f90_template
559570 body = wrap_timing (body , langs .language )
560571 if data is not None :
561- code += wrap_data (body + "\n " , langs , data , preprocessor , user_dimensions )
562- else :
563- code += body + "\n "
564- if len (deinitialization ) > 1 :
565- code += deinitialization + "\n "
572+ body = wrap_data (body + "\n " , langs , data , preprocessor , user_dimensions )
566573 name = signature .split (" " )[1 ].split ("(" )[0 ]
567- code += f"\n end function { name } \n end module kt\n "
574+ code = code .replace ("<!?NAME?!>" , name )
575+ code = code .replace ("<!?PREPROCESSOR?!>" , "\n " .join (preprocessor ))
576+ # if present, add user specific dimensions as defines
577+ if user_dimensions is not None :
578+ user_defines = ""
579+ for key , value in user_dimensions .items ():
580+ user_defines += f"#define { key } { value } \n "
581+ code = code .replace ("<!?USER_DEFINES?!>" , user_defines )
582+ else :
583+ code = code .replace ("<!?USER_DEFINES?!>" , "" )
584+ code = code .replace ("<!?SIGNATURE?!>" , signature )
585+ code = code .replace ("<!?INITIALIZATION?!>" , initialization )
586+ code = code .replace ("<!?DEINITIALIZATION?!>" , deinitialization )
587+ if data is not None :
588+ body = add_present_openacc (body , langs , data , preprocessor , user_dimensions )
589+ code = code .replace ("<!?BODY?!>" , body )
568590
569591 return code
570592
@@ -662,3 +684,21 @@ def add_present_openacc_fortran(name: str, size: ArraySize) -> str:
662684 else :
663685 md_size = fortran_md_size (size )
664686 return f" present({ name } ({ ',' .join (md_size )} )) "
687+
688+
689+ def process_directives (langs : Code , source : str , user_dimensions : dict = None ) -> Tuple [dict , dict ]:
690+ """Helper functions to process all the directives in the code and create tunable functions"""
691+ kernel_strings = dict ()
692+ kernel_args = dict ()
693+ preprocessor = extract_preprocessor (source )
694+ signatures = extract_directive_signature (source , langs )
695+ bodies = extract_directive_code (source , langs )
696+ data = extract_directive_data (source , langs )
697+ init = extract_initialization_code (source , langs )
698+ deinit = extract_deinitialization_code (source , langs )
699+ for kernel in signatures .keys ():
700+ kernel_strings [kernel ] = generate_directive_function (
701+ preprocessor , signatures [kernel ], bodies [kernel ], langs , data [kernel ], init , deinit , user_dimensions
702+ )
703+ kernel_args [kernel ] = allocate_signature_memory (data [kernel ], preprocessor , user_dimensions )
704+ return (kernel_strings , kernel_args )
0 commit comments