@@ -48,6 +48,47 @@ def __init__(self, directive: Directive, lang: Language):
4848 self .language = lang
4949
5050
51+ class ArraySize :
52+ """Size of an array"""
53+
54+ def __init__ (self ):
55+ self .size = list ()
56+
57+ def __iter__ (self ):
58+ for i in self .size :
59+ yield i
60+
61+ def __len__ (self ):
62+ return len (self .size )
63+
64+ def clear (self ):
65+ self .size .clear ()
66+
67+ def get (self ) -> int :
68+ length = len (self .size )
69+ if length == 0 :
70+ return 0
71+ elif length == 1 :
72+ return self .size [0 ]
73+ else :
74+ product = 1
75+ for i in self .size :
76+ product *= i
77+ return product
78+
79+ def add (self , dim : int ) -> None :
80+ # Only allow adding valid dimensions
81+ if dim >= 1 :
82+ self .size .append (dim )
83+
84+
85+ def fortran_md_size (size : ArraySize ) -> list :
86+ md_size = list ()
87+ for dim in size :
88+ md_size .append (f":{ dim } " )
89+ return md_size
90+
91+
5192def is_openacc (directive : Directive ) -> bool :
5293 """Check if a directive is OpenACC"""
5394 return isinstance (directive , OpenACC )
@@ -120,7 +161,7 @@ def openacc_directive_contains_data_clause(line: str) -> bool:
120161 return openacc_directive_contains_clause (line , data_clauses )
121162
122163
123- def create_data_directive_openacc (name : str , size : int , lang : Language ) -> str :
164+ def create_data_directive_openacc (name : str , size : ArraySize , lang : Language ) -> str :
124165 """Create a data directive for a given language"""
125166 if is_cxx (lang ):
126167 return create_data_directive_openacc_cxx (name , size )
@@ -129,17 +170,23 @@ def create_data_directive_openacc(name: str, size: int, lang: Language) -> str:
129170 return ""
130171
131172
132- def create_data_directive_openacc_cxx (name : str , size : int ) -> str :
173+ def create_data_directive_openacc_cxx (name : str , size : ArraySize ) -> str :
133174 """Create C++ OpenACC code to allocate and copy data"""
134- return f"#pragma acc enter data create({ name } [:{ size } ])\n #pragma acc update device({ name } [:{ size } ])\n "
175+ return f"#pragma acc enter data create({ name } [:{ size . get () } ])\n #pragma acc update device({ name } [:{ size . get () } ])\n "
135176
136177
137- def create_data_directive_openacc_fortran (name : str , size : int ) -> str :
178+ def create_data_directive_openacc_fortran (name : str , size : ArraySize ) -> str :
138179 """Create Fortran OpenACC code to allocate and copy data"""
139- return f"!$acc enter data create({ name } (:{ size } ))\n !$acc update device({ name } (:{ size } ))\n "
180+ if len (size ) == 1 :
181+ return f"!$acc enter data create({ name } (:{ size .get ()} ))\n !$acc update device({ name } (:{ size .get ()} ))\n "
182+ else :
183+ md_size = fortran_md_size (size )
184+ return (
185+ f"!$acc enter data create({ name } ({ ',' .join (md_size )} ))\n !$acc update device({ name } ({ ',' .join (md_size )} ))\n "
186+ )
140187
141188
142- def exit_data_directive_openacc (name : str , size : int , lang : Language ) -> str :
189+ def exit_data_directive_openacc (name : str , size : ArraySize , lang : Language ) -> str :
143190 """Create code to copy data back for a given language"""
144191 if is_cxx (lang ):
145192 return exit_data_directive_openacc_cxx (name , size )
@@ -148,14 +195,18 @@ def exit_data_directive_openacc(name: str, size: int, lang: Language) -> str:
148195 return ""
149196
150197
151- def exit_data_directive_openacc_cxx (name : str , size : int ) -> str :
198+ def exit_data_directive_openacc_cxx (name : str , size : ArraySize ) -> str :
152199 """Create C++ OpenACC code to copy back data"""
153- return f"#pragma acc exit data copyout({ name } [:{ size } ])\n "
200+ return f"#pragma acc exit data copyout({ name } [:{ size . get () } ])\n "
154201
155202
156- def exit_data_directive_openacc_fortran (name : str , size : int ) -> str :
203+ def exit_data_directive_openacc_fortran (name : str , size : ArraySize ) -> str :
157204 """Create Fortran OpenACC code to copy back data"""
158- return f"!$acc exit data copyout({ name } (:{ size } ))\n "
205+ if len (size ) == 1 :
206+ return f"!$acc exit data copyout({ name } (:{ size .get ()} ))\n "
207+ else :
208+ md_size = fortran_md_size (size )
209+ return f"!$acc exit data copyout({ name } ({ ',' .join (md_size )} ))\n "
159210
160211
161212def correct_kernel (kernel_name : str , line : str ) -> bool :
@@ -165,7 +216,7 @@ def correct_kernel(kernel_name: str, line: str) -> bool:
165216
166217def find_size_in_preprocessor (dimension : str , preprocessor : list ) -> int :
167218 """Find the dimension of a directive defined value in the preprocessor"""
168- ret_size = None
219+ ret_size = 0
169220 for line in preprocessor :
170221 if f"#define { dimension } " in line :
171222 try :
@@ -209,45 +260,43 @@ def extract_code(start: str, stop: str, code: str, langs: Code, kernel_name: str
209260 return sections
210261
211262
212- def parse_size (size : Any , preprocessor : list = None , dimensions : dict = None ) -> int :
263+ def parse_size (size : Any , preprocessor : list = None , dimensions : dict = None ) -> ArraySize :
213264 """Converts an arbitrary object into an integer representing memory size"""
214- ret_size = None
265+ ret_size = ArraySize ()
215266 if type (size ) is not int :
216267 try :
217268 # Try to convert the size to an integer
218- ret_size = int (size )
269+ ret_size . add ( int (size ) )
219270 except ValueError :
220271 # If size cannot be natively converted to an int, we try to derive it from the preprocessor
221- if preprocessor is not None :
222- try :
272+ try :
273+ if preprocessor is not None :
223274 if "," in size :
224- ret_size = 1
225275 for dimension in size .split ("," ):
226- ret_size *= find_size_in_preprocessor (dimension , preprocessor )
276+ ret_size . add ( find_size_in_preprocessor (dimension , preprocessor ) )
227277 else :
228- ret_size = find_size_in_preprocessor (size , preprocessor )
229- except TypeError :
230- # preprocessor is available but does not contain the dimensions
231- pass
278+ ret_size . add ( find_size_in_preprocessor (size , preprocessor ) )
279+ except TypeError :
280+ # At least one of the dimension cannot be derived from the preprocessor
281+ pass
232282 # If size cannot be natively converted, nor retrieved from the preprocessor, we check user provided values
233283 if dimensions is not None :
234284 if size in dimensions .keys ():
235285 try :
236- ret_size = int (dimensions [size ])
286+ ret_size . add ( int (dimensions [size ]) )
237287 except ValueError :
238288 # User error, no mitigation
239289 return ret_size
240290 elif "," in size :
241- ret_size = 1
242291 for dimension in size .split ("," ):
243292 try :
244- ret_size *= int (dimensions [dimension ])
293+ ret_size . add ( int (dimensions [dimension ]) )
245294 except ValueError :
246295 # User error, no mitigation
247- return None
296+ return ret_size
248297 else :
249298 # size is already an int. no need for conversion
250- ret_size = size
299+ ret_size . add ( size )
251300
252301 return ret_size
253302
@@ -297,8 +346,13 @@ def wrap_data(code: str, langs: Code, data: dict, preprocessor: list = None, use
297346 intro += create_data_directive_openacc_cxx (name , size )
298347 outro += exit_data_directive_openacc_cxx (name , size )
299348 elif is_openacc (langs .directive ) and is_fortran (langs .language ):
300- intro += create_data_directive_openacc_fortran (name , size )
301- outro += exit_data_directive_openacc_fortran (name , size )
349+ if "," in data [name ][1 ]:
350+ # Multi dimensional
351+ pass
352+ else :
353+ # One dimensional
354+ intro += create_data_directive_openacc_fortran (name , size )
355+ outro += exit_data_directive_openacc_fortran (name , size )
302356 return intro + code + outro
303357
304358
@@ -537,9 +591,9 @@ def allocate_signature_memory(data: dict, preprocessor: list = None, user_dimens
537591 p_type = data [parameter ][0 ]
538592 size = parse_size (data [parameter ][1 ], preprocessor , user_dimensions )
539593 if "*" in p_type :
540- args .append (allocate_array (p_type , size ))
594+ args .append (allocate_array (p_type , size . get () ))
541595 else :
542- args .append (allocate_scalar (p_type , size ))
596+ args .append (allocate_scalar (p_type , size . get () ))
543597
544598 return args
545599
@@ -579,11 +633,15 @@ def add_present_openacc(
579633 return new_body
580634
581635
582- def add_present_openacc_cxx (name : str , size : int ) -> str :
636+ def add_present_openacc_cxx (name : str , size : ArraySize ) -> str :
583637 """Create present clause for C++ OpenACC directive"""
584- return f" present({ name } [:{ size } ]) "
638+ return f" present({ name } [:{ size . get () } ]) "
585639
586640
587- def add_present_openacc_fortran (name : str , size : int ) -> str :
641+ def add_present_openacc_fortran (name : str , size : ArraySize ) -> str :
588642 """Create present clause for Fortran OpenACC directive"""
589- return f" present({ name } (:{ size } )) "
643+ if len (size ) == 1 :
644+ return f" present({ name } (:{ size .get ()} )) "
645+ else :
646+ md_size = fortran_md_size (size )
647+ return f" present({ name } ({ ',' .join (md_size )} )) "
0 commit comments