Skip to content

Commit b96efca

Browse files
committed
Make zips strict in pytensor/link/c
1 parent 6fb54a6 commit b96efca

File tree

4 files changed

+20
-12
lines changed

4 files changed

+20
-12
lines changed

pytensor/link/c/basic.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,11 +1119,15 @@ def __compile__(
11191119
module,
11201120
[
11211121
Container(input, storage)
1122-
for input, storage in zip(self.fgraph.inputs, input_storage)
1122+
for input, storage in zip(
1123+
self.fgraph.inputs, input_storage, strict=True
1124+
)
11231125
],
11241126
[
11251127
Container(output, storage, readonly=True)
1126-
for output, storage in zip(self.fgraph.outputs, output_storage)
1128+
for output, storage in zip(
1129+
self.fgraph.outputs, output_storage, strict=True
1130+
)
11271131
],
11281132
error_storage,
11291133
)
@@ -1891,11 +1895,11 @@ def make_all(
18911895
f,
18921896
[
18931897
Container(input, storage)
1894-
for input, storage in zip(fgraph.inputs, input_storage)
1898+
for input, storage in zip(fgraph.inputs, input_storage, strict=True)
18951899
],
18961900
[
18971901
Container(output, storage, readonly=True)
1898-
for output, storage in zip(fgraph.outputs, output_storage)
1902+
for output, storage in zip(fgraph.outputs, output_storage, strict=True)
18991903
],
19001904
thunks,
19011905
order,
@@ -1993,22 +1997,26 @@ def make_thunk(self, **kwargs):
19931997
)
19941998

19951999
def f():
1996-
for input1, input2 in zip(i1, i2):
2000+
for input1, input2 in zip(i1, i2, strict=True):
19972001
# Set the inputs to be the same in both branches.
19982002
# The copy is necessary in order for inplace ops not to
19992003
# interfere.
20002004
input2.storage[0] = copy(input1.storage[0])
2001-
for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2):
2002-
for output, storage in zip(node1.outputs, thunk1.outputs):
2005+
for thunk1, thunk2, node1, node2 in zip(
2006+
thunks1, thunks2, order1, order2, strict=True
2007+
):
2008+
for output, storage in zip(node1.outputs, thunk1.outputs, strict=True):
20032009
if output in no_recycling:
20042010
storage[0] = None
2005-
for output, storage in zip(node2.outputs, thunk2.outputs):
2011+
for output, storage in zip(node2.outputs, thunk2.outputs, strict=True):
20062012
if output in no_recycling:
20072013
storage[0] = None
20082014
try:
20092015
thunk1()
20102016
thunk2()
2011-
for output1, output2 in zip(thunk1.outputs, thunk2.outputs):
2017+
for output1, output2 in zip(
2018+
thunk1.outputs, thunk2.outputs, strict=True
2019+
):
20122020
self.checker(output1, output2)
20132021
except Exception:
20142022
raise_with_op(fgraph, node1)

pytensor/link/c/cmodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2443,7 +2443,7 @@ def patch_ldflags(flag_list: list[str]) -> list[str]:
24432443
if not libs:
24442444
return flag_list
24452445
libs = GCC_compiler.linking_patch(lib_dirs, libs)
2446-
for flag_idx, lib in zip(flag_idxs, libs):
2446+
for flag_idx, lib in zip(flag_idxs, libs, strict=True):
24472447
flag_list[flag_idx] = lib
24482448
return flag_list
24492449

pytensor/link/c/op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def make_c_thunk(
5959
e = FunctionGraph(node.inputs, node.outputs)
6060
e_no_recycling = [
6161
new_o
62-
for (new_o, old_o) in zip(e.outputs, node.outputs)
62+
for (new_o, old_o) in zip(e.outputs, node.outputs, strict=True)
6363
if old_o in no_recycling
6464
]
6565
cl = pytensor.link.c.basic.CLinker().accept(e, no_recycling=e_no_recycling)

pytensor/link/c/params_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def c_support_code(self, **kwargs):
709709
c_init_list = []
710710
c_cleanup_list = []
711711
c_extract_list = []
712-
for attribute_name, type_instance in zip(self.fields, self.types):
712+
for attribute_name, type_instance in zip(self.fields, self.types, strict=True):
713713
try:
714714
# c_support_code() may return a code string or a list of code strings.
715715
support_code = type_instance.c_support_code()

0 commit comments

Comments
 (0)