Skip to content

Commit 7a1569b

Browse files
committed
Make non-strict zips strict in tensor/elemwise_cgen
1 parent 34a587e commit 7a1569b

File tree

1 file changed

+15
-23
lines changed

1 file changed

+15
-23
lines changed

pytensor/tensor/elemwise_cgen.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,13 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
206206
""".format(**dict(locals(), **sub))
207207

208208

209-
def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
209+
def make_loop(
210+
loop_orders: list[tuple[int | str, ...]],
211+
dtypes: list,
212+
loop_tasks: list,
213+
sub: dict[str, str],
214+
openmp: bool = False,
215+
):
210216
"""
211217
Make a nested loop over several arrays and associate specific code
212218
to each level of nesting.
@@ -224,7 +230,7 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
224230
string is code to be executed before the ith loop starts, the second
225231
one contains code to be executed just before going to the next element
226232
of the ith dimension.
227-
The last element if loop_tasks is a single string, containing code
233+
The last element of loop_tasks is a single string, containing code
228234
to be executed at the very end.
229235
sub : dictionary
230236
Maps 'lv#' to a suitable variable name.
@@ -257,7 +263,7 @@ def loop_over(preloop, code, indices, i):
257263
}}
258264
"""
259265

260-
preloops = {}
266+
preloops: dict[int, str] = {}
261267
for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)):
262268
for j, index in enumerate(loop_order):
263269
if index != "x":
@@ -274,16 +280,8 @@ def loop_over(preloop, code, indices, i):
274280

275281
s = ""
276282

277-
for i, (pre_task, task), indices in reversed(
278-
list(
279-
zip(
280-
range(len(loop_tasks) - 1),
281-
loop_tasks,
282-
list(zip(*loop_orders, strict=True)),
283-
strict=False,
284-
)
285-
)
286-
):
283+
tasks_indices = zip(loop_tasks[:-1], zip(*loop_orders, strict=True), strict=True)
284+
for i, ((pre_task, task), indices) in reversed(list(enumerate(tasks_indices))):
287285
s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i)
288286

289287
s += loop_tasks[-1]
@@ -549,16 +547,10 @@ def loop_over(preloop, code, indices, i):
549547
s = preloops.get(0, "")
550548
else:
551549
s = ""
552-
for i, (pre_task, task), indices in reversed(
553-
list(
554-
zip(
555-
range(len(loop_tasks) - 1),
556-
loop_tasks,
557-
list(zip(*loop_orders, strict=True)),
558-
strict=False,
559-
)
560-
)
561-
):
550+
tasks_indices = zip(
551+
loop_tasks[:-1], zip(*loop_orders, strict=True), strict=True
552+
)
553+
for i, ((pre_task, task), indices) in reversed(list(enumerate(tasks_indices))):
562554
s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i)
563555

564556
s += loop_tasks[-1]

0 commit comments

Comments
 (0)