Skip to content

Commit 0d33b37

Browse files
committed
Make non-strict zips strict in tensor/elemwise_cgen
1 parent 9ac5324 commit 0d33b37

File tree

1 file changed

+15
-23
lines changed

1 file changed

+15
-23
lines changed

pytensor/tensor/elemwise_cgen.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,13 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
208208
"""
209209

210210

211-
def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
211+
def make_loop(
212+
loop_orders: list[tuple[int | str, ...]],
213+
dtypes: list,
214+
loop_tasks: list,
215+
sub: dict[str, str],
216+
openmp: bool = False,
217+
):
212218
"""
213219
Make a nested loop over several arrays and associate specific code
214220
to each level of nesting.
@@ -226,7 +232,7 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
226232
string is code to be executed before the ith loop starts, the second
227233
one contains code to be executed just before going to the next element
228234
of the ith dimension.
229-
The last element if loop_tasks is a single string, containing code
235+
The last element of loop_tasks is a single string, containing code
230236
to be executed at the very end.
231237
sub : dictionary
232238
Maps 'lv#' to a suitable variable name.
@@ -259,7 +265,7 @@ def loop_over(preloop, code, indices, i):
259265
}}
260266
"""
261267

262-
preloops = {}
268+
preloops: dict[int, str] = {}
263269
for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)):
264270
for j, index in enumerate(loop_order):
265271
if index != "x":
@@ -276,16 +282,8 @@ def loop_over(preloop, code, indices, i):
276282

277283
s = ""
278284

279-
for i, (pre_task, task), indices in reversed(
280-
list(
281-
zip(
282-
range(len(loop_tasks) - 1),
283-
loop_tasks,
284-
list(zip(*loop_orders, strict=True)),
285-
strict=False,
286-
)
287-
)
288-
):
285+
tasks_indices = zip(loop_tasks[:-1], zip(*loop_orders, strict=True), strict=True)
286+
for i, ((pre_task, task), indices) in reversed(list(enumerate(tasks_indices))):
289287
s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i)
290288

291289
s += loop_tasks[-1]
@@ -549,16 +547,10 @@ def loop_over(preloop, code, indices, i):
549547
s = preloops.get(0, "")
550548
else:
551549
s = ""
552-
for i, (pre_task, task), indices in reversed(
553-
list(
554-
zip(
555-
range(len(loop_tasks) - 1),
556-
loop_tasks,
557-
list(zip(*loop_orders, strict=True)),
558-
strict=False,
559-
)
560-
)
561-
):
550+
tasks_indices = zip(
551+
loop_tasks[:-1], zip(*loop_orders, strict=True), strict=True
552+
)
553+
for i, ((pre_task, task), indices) in reversed(list(enumerate(tasks_indices))):
562554
s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i)
563555

564556
s += loop_tasks[-1]

0 commit comments

Comments
 (0)