@@ -48,6 +48,48 @@ def __init__(self, directive: Directive, lang: Language):
4848 self .language = lang
4949
5050
51+ class ArraySize (object ):
52+ """Size of an array"""
53+
54+ def __init__ (self ):
55+ self .size = list ()
56+
57+ def __iter__ (self ):
58+ for i in self .size :
59+ yield i
60+
61+ def __len__ (self ):
62+ return len (self .size )
63+
64+ def clear (self ):
65+ self .size .clear ()
66+
67+ def get (self ) -> int :
68+ length = len (self .size )
69+ if length == 0 :
70+ return 0
71+ elif length == 1 :
72+ return self .size [0 ]
73+ else :
74+ product = 1
75+ for i in self .size :
76+ product *= i
77+ return product
78+
79+ def add (self , dim : int ) -> None :
80+ # Only allow adding valid dimensions
81+ if dim >= 1 :
82+ self .size .append (dim )
83+
84+
85+ def fortran_md_size (size : ArraySize ) -> list :
86+ """Format a multidimensional size into the correct Fortran string"""
87+ md_size = list ()
88+ for dim in size :
89+ md_size .append (f":{ dim } " )
90+ return md_size
91+
92+
5193def is_openacc (directive : Directive ) -> bool :
5294 """Check if a directive is OpenACC"""
5395 return isinstance (directive , OpenACC )
@@ -120,7 +162,7 @@ def openacc_directive_contains_data_clause(line: str) -> bool:
120162 return openacc_directive_contains_clause (line , data_clauses )
121163
122164
123- def create_data_directive_openacc (name : str , size : int , lang : Language ) -> str :
165+ def create_data_directive_openacc (name : str , size : ArraySize , lang : Language ) -> str :
124166 """Create a data directive for a given language"""
125167 if is_cxx (lang ):
126168 return create_data_directive_openacc_cxx (name , size )
@@ -129,17 +171,23 @@ def create_data_directive_openacc(name: str, size: int, lang: Language) -> str:
129171 return ""
130172
131173
132- def create_data_directive_openacc_cxx (name : str , size : int ) -> str :
174+ def create_data_directive_openacc_cxx (name : str , size : ArraySize ) -> str :
133175 """Create C++ OpenACC code to allocate and copy data"""
134- return f"#pragma acc enter data create({ name } [:{ size } ])\n #pragma acc update device({ name } [:{ size } ])\n "
176+ return f"#pragma acc enter data create({ name } [:{ size . get () } ])\n #pragma acc update device({ name } [:{ size . get () } ])\n "
135177
136178
137- def create_data_directive_openacc_fortran (name : str , size : int ) -> str :
179+ def create_data_directive_openacc_fortran (name : str , size : ArraySize ) -> str :
138180 """Create Fortran OpenACC code to allocate and copy data"""
139- return f"!$acc enter data create({ name } (:{ size } ))\n !$acc update device({ name } (:{ size } ))\n "
181+ if len (size ) == 1 :
182+ return f"!$acc enter data create({ name } (:{ size .get ()} ))\n !$acc update device({ name } (:{ size .get ()} ))\n "
183+ else :
184+ md_size = fortran_md_size (size )
185+ return (
186+ f"!$acc enter data create({ name } ({ ',' .join (md_size )} ))\n !$acc update device({ name } ({ ',' .join (md_size )} ))\n "
187+ )
140188
141189
142- def exit_data_directive_openacc (name : str , size : int , lang : Language ) -> str :
190+ def exit_data_directive_openacc (name : str , size : ArraySize , lang : Language ) -> str :
143191 """Create code to copy data back for a given language"""
144192 if is_cxx (lang ):
145193 return exit_data_directive_openacc_cxx (name , size )
@@ -148,14 +196,18 @@ def exit_data_directive_openacc(name: str, size: int, lang: Language) -> str:
148196 return ""
149197
150198
151- def exit_data_directive_openacc_cxx (name : str , size : int ) -> str :
199+ def exit_data_directive_openacc_cxx (name : str , size : ArraySize ) -> str :
152200 """Create C++ OpenACC code to copy back data"""
153- return f"#pragma acc exit data copyout({ name } [:{ size } ])\n "
201+ return f"#pragma acc exit data copyout({ name } [:{ size . get () } ])\n "
154202
155203
156- def exit_data_directive_openacc_fortran (name : str , size : int ) -> str :
204+ def exit_data_directive_openacc_fortran (name : str , size : ArraySize ) -> str :
157205 """Create Fortran OpenACC code to copy back data"""
158- return f"!$acc exit data copyout({ name } (:{ size } ))\n "
206+ if len (size ) == 1 :
207+ return f"!$acc exit data copyout({ name } (:{ size .get ()} ))\n "
208+ else :
209+ md_size = fortran_md_size (size )
210+ return f"!$acc exit data copyout({ name } ({ ',' .join (md_size )} ))\n "
159211
160212
161213def correct_kernel (kernel_name : str , line : str ) -> bool :
@@ -165,7 +217,7 @@ def correct_kernel(kernel_name: str, line: str) -> bool:
165217
166218def find_size_in_preprocessor (dimension : str , preprocessor : list ) -> int :
167219 """Find the dimension of a directive defined value in the preprocessor"""
168- ret_size = None
220+ ret_size = 0
169221 for line in preprocessor :
170222 if f"#define { dimension } " in line :
171223 try :
@@ -209,45 +261,43 @@ def extract_code(start: str, stop: str, code: str, langs: Code, kernel_name: str
209261 return sections
210262
211263
212- def parse_size (size : Any , preprocessor : list = None , dimensions : dict = None ) -> int :
264+ def parse_size (size : Any , preprocessor : list = None , dimensions : dict = None ) -> ArraySize :
213265 """Converts an arbitrary object into an integer representing memory size"""
214- ret_size = None
266+ ret_size = ArraySize ()
215267 if type (size ) is not int :
216268 try :
217269 # Try to convert the size to an integer
218- ret_size = int (size )
270+ ret_size . add ( int (size ) )
219271 except ValueError :
220272 # If size cannot be natively converted to an int, we try to derive it from the preprocessor
221- if preprocessor is not None :
222- try :
273+ try :
274+ if preprocessor is not None :
223275 if "," in size :
224- ret_size = 1
225276 for dimension in size .split ("," ):
226- ret_size *= find_size_in_preprocessor (dimension , preprocessor )
277+ ret_size . add ( find_size_in_preprocessor (dimension , preprocessor ) )
227278 else :
228- ret_size = find_size_in_preprocessor (size , preprocessor )
229- except TypeError :
230- # preprocessor is available but does not contain the dimensions
231- pass
279+ ret_size . add ( find_size_in_preprocessor (size , preprocessor ) )
280+ except TypeError :
281+ # At least one of the dimension cannot be derived from the preprocessor
282+ pass
232283 # If size cannot be natively converted, nor retrieved from the preprocessor, we check user provided values
233284 if dimensions is not None :
234285 if size in dimensions .keys ():
235286 try :
236- ret_size = int (dimensions [size ])
287+ ret_size . add ( int (dimensions [size ]) )
237288 except ValueError :
238289 # User error, no mitigation
239290 return ret_size
240291 elif "," in size :
241- ret_size = 1
242292 for dimension in size .split ("," ):
243293 try :
244- ret_size *= int (dimensions [dimension ])
294+ ret_size . add ( int (dimensions [dimension ]) )
245295 except ValueError :
246296 # User error, no mitigation
247- return None
297+ return ret_size
248298 else :
249299 # size is already an int. no need for conversion
250- ret_size = size
300+ ret_size . add ( size )
251301
252302 return ret_size
253303
@@ -297,8 +347,13 @@ def wrap_data(code: str, langs: Code, data: dict, preprocessor: list = None, use
297347 intro += create_data_directive_openacc_cxx (name , size )
298348 outro += exit_data_directive_openacc_cxx (name , size )
299349 elif is_openacc (langs .directive ) and is_fortran (langs .language ):
300- intro += create_data_directive_openacc_fortran (name , size )
301- outro += exit_data_directive_openacc_fortran (name , size )
350+ if "," in data [name ][1 ]:
351+ # Multi dimensional
352+ pass
353+ else :
354+ # One dimensional
355+ intro += create_data_directive_openacc_fortran (name , size )
356+ outro += exit_data_directive_openacc_fortran (name , size )
302357 return intro + code + outro
303358
304359
@@ -537,9 +592,9 @@ def allocate_signature_memory(data: dict, preprocessor: list = None, user_dimens
537592 p_type = data [parameter ][0 ]
538593 size = parse_size (data [parameter ][1 ], preprocessor , user_dimensions )
539594 if "*" in p_type :
540- args .append (allocate_array (p_type , size ))
595+ args .append (allocate_array (p_type , size . get () ))
541596 else :
542- args .append (allocate_scalar (p_type , size ))
597+ args .append (allocate_scalar (p_type , size . get () ))
543598
544599 return args
545600
@@ -579,11 +634,15 @@ def add_present_openacc(
579634 return new_body
580635
581636
582- def add_present_openacc_cxx (name : str , size : int ) -> str :
637+ def add_present_openacc_cxx (name : str , size : ArraySize ) -> str :
583638 """Create present clause for C++ OpenACC directive"""
584- return f" present({ name } [:{ size } ]) "
639+ return f" present({ name } [:{ size . get () } ]) "
585640
586641
587- def add_present_openacc_fortran (name : str , size : int ) -> str :
642+ def add_present_openacc_fortran (name : str , size : ArraySize ) -> str :
588643 """Create present clause for Fortran OpenACC directive"""
589- return f" present({ name } (:{ size } )) "
644+ if len (size ) == 1 :
645+ return f" present({ name } (:{ size .get ()} )) "
646+ else :
647+ md_size = fortran_md_size (size )
648+ return f" present({ name } ({ ',' .join (md_size )} )) "
0 commit comments