Skip to content

Commit b74157c

Browse files
committed
Make non-strict zips strict in tensor/elemwise_cgen
1 parent 163c67b commit b74157c

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

pytensor/tensor/elemwise_cgen.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,13 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
209209
)
210210

211211

212-
def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
212+
def make_loop(
213+
loop_orders: list[tuple[int | str, ...]],
214+
dtypes: list,
215+
loop_tasks: list,
216+
sub: dict[str, str],
217+
openmp: bool = False,
218+
):
213219
"""
214220
Make a nested loop over several arrays and associate specific code
215221
to each level of nesting.
@@ -227,7 +233,7 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
227233
string is code to be executed before the ith loop starts, the second
228234
one contains code to be executed just before going to the next element
229235
of the ith dimension.
230-
The last element if loop_tasks is a single string, containing code
236+
The last element of loop_tasks is a single string, containing code
231237
to be executed at the very end.
232238
sub : dictionary
233239
Maps 'lv#' to a suitable variable name.
@@ -260,7 +266,7 @@ def loop_over(preloop, code, indices, i):
260266
}}
261267
"""
262268

263-
preloops = {}
269+
preloops: dict[int, str] = {}
264270
for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)):
265271
for j, index in enumerate(loop_order):
266272
if index != "x":
@@ -277,16 +283,8 @@ def loop_over(preloop, code, indices, i):
277283

278284
s = ""
279285

280-
for i, (pre_task, task), indices in reversed(
281-
list(
282-
zip(
283-
range(len(loop_tasks) - 1),
284-
loop_tasks,
285-
list(zip(*loop_orders, strict=True)),
286-
strict=False,
287-
)
288-
)
289-
):
286+
tasks_indices = zip(loop_tasks[:-1], zip(*loop_orders, strict=True), strict=True)
287+
for i, ((pre_task, task), indices) in reversed(list(enumerate(tasks_indices))):
290288
s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i)
291289

292290
s += loop_tasks[-1]

0 commit comments

Comments
 (0)