@@ -206,7 +206,13 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
206
206
""" .format (** dict (locals (), ** sub ))
207
207
208
208
209
- def make_loop (loop_orders , dtypes , loop_tasks , sub , openmp = None ):
209
+ def make_loop (
210
+ loop_orders : list [tuple [int | str , ...]],
211
+ dtypes : list ,
212
+ loop_tasks : list ,
213
+ sub : dict [str , str ],
214
+ openmp : bool = False ,
215
+ ):
210
216
"""
211
217
Make a nested loop over several arrays and associate specific code
212
218
to each level of nesting.
@@ -224,7 +230,7 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
224
230
string is code to be executed before the ith loop starts, the second
225
231
one contains code to be executed just before going to the next element
226
232
of the ith dimension.
227
- The last element if loop_tasks is a single string, containing code
233
+ The last element of loop_tasks is a single string, containing code
228
234
to be executed at the very end.
229
235
sub : dictionary
230
236
Maps 'lv#' to a suitable variable name.
@@ -257,7 +263,7 @@ def loop_over(preloop, code, indices, i):
257
263
}}
258
264
"""
259
265
260
- preloops = {}
266
+ preloops : dict [ int , str ] = {}
261
267
for i , (loop_order , dtype ) in enumerate (zip (loop_orders , dtypes , strict = True )):
262
268
for j , index in enumerate (loop_order ):
263
269
if index != "x" :
@@ -274,16 +280,8 @@ def loop_over(preloop, code, indices, i):
274
280
275
281
s = ""
276
282
277
- for i , (pre_task , task ), indices in reversed (
278
- list (
279
- zip (
280
- range (len (loop_tasks ) - 1 ),
281
- loop_tasks ,
282
- list (zip (* loop_orders , strict = True )),
283
- strict = False ,
284
- )
285
- )
286
- ):
283
+ tasks_indices = zip (loop_tasks [:- 1 ], zip (* loop_orders , strict = True ), strict = True )
284
+ for i , ((pre_task , task ), indices ) in reversed (list (enumerate (tasks_indices ))):
287
285
s = loop_over (preloops .get (i , "" ) + pre_task , s + task , indices , i )
288
286
289
287
s += loop_tasks [- 1 ]
@@ -549,16 +547,10 @@ def loop_over(preloop, code, indices, i):
549
547
s = preloops .get (0 , "" )
550
548
else :
551
549
s = ""
552
- for i , (pre_task , task ), indices in reversed (
553
- list (
554
- zip (
555
- range (len (loop_tasks ) - 1 ),
556
- loop_tasks ,
557
- list (zip (* loop_orders , strict = True )),
558
- strict = False ,
559
- )
560
- )
561
- ):
550
+ tasks_indices = zip (
551
+ loop_tasks [:- 1 ], zip (* loop_orders , strict = True ), strict = True
552
+ )
553
+ for i , ((pre_task , task ), indices ) in reversed (list (enumerate (tasks_indices ))):
562
554
s = loop_over (preloops .get (i , "" ) + pre_task , s + task , indices , i )
563
555
564
556
s += loop_tasks [- 1 ]
0 commit comments