Skip to content

Commit 750a618

Browse files
committed
Make zips strict in pytensor/link/c
1 parent 865b29a commit 750a618

File tree

4 files changed

+20
-12
lines changed

4 files changed

+20
-12
lines changed

pytensor/link/c/basic.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,11 +1119,15 @@ def __compile__(
11191119
module,
11201120
[
11211121
Container(input, storage)
1122-
for input, storage in zip(self.fgraph.inputs, input_storage)
1122+
for input, storage in zip(
1123+
self.fgraph.inputs, input_storage, strict=True
1124+
)
11231125
],
11241126
[
11251127
Container(output, storage, readonly=True)
1126-
for output, storage in zip(self.fgraph.outputs, output_storage)
1128+
for output, storage in zip(
1129+
self.fgraph.outputs, output_storage, strict=True
1130+
)
11271131
],
11281132
error_storage,
11291133
)
@@ -1893,11 +1897,11 @@ def make_all(
18931897
f,
18941898
[
18951899
Container(input, storage)
1896-
for input, storage in zip(fgraph.inputs, input_storage)
1900+
for input, storage in zip(fgraph.inputs, input_storage, strict=True)
18971901
],
18981902
[
18991903
Container(output, storage, readonly=True)
1900-
for output, storage in zip(fgraph.outputs, output_storage)
1904+
for output, storage in zip(fgraph.outputs, output_storage, strict=True)
19011905
],
19021906
thunks,
19031907
order,
@@ -1995,22 +1999,26 @@ def make_thunk(self, **kwargs):
19951999
)
19962000

19972001
def f():
1998-
for input1, input2 in zip(i1, i2):
2002+
for input1, input2 in zip(i1, i2, strict=True):
19992003
# Set the inputs to be the same in both branches.
20002004
# The copy is necessary in order for inplace ops not to
20012005
# interfere.
20022006
input2.storage[0] = copy(input1.storage[0])
2003-
for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2):
2004-
for output, storage in zip(node1.outputs, thunk1.outputs):
2007+
for thunk1, thunk2, node1, node2 in zip(
2008+
thunks1, thunks2, order1, order2, strict=True
2009+
):
2010+
for output, storage in zip(node1.outputs, thunk1.outputs, strict=True):
20052011
if output in no_recycling:
20062012
storage[0] = None
2007-
for output, storage in zip(node2.outputs, thunk2.outputs):
2013+
for output, storage in zip(node2.outputs, thunk2.outputs, strict=True):
20082014
if output in no_recycling:
20092015
storage[0] = None
20102016
try:
20112017
thunk1()
20122018
thunk2()
2013-
for output1, output2 in zip(thunk1.outputs, thunk2.outputs):
2019+
for output1, output2 in zip(
2020+
thunk1.outputs, thunk2.outputs, strict=True
2021+
):
20142022
self.checker(output1, output2)
20152023
except Exception:
20162024
raise_with_op(fgraph, node1)

pytensor/link/c/cmodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2443,7 +2443,7 @@ def patch_ldflags(flag_list: list[str]) -> list[str]:
24432443
if not libs:
24442444
return flag_list
24452445
libs = GCC_compiler.linking_patch(lib_dirs, libs)
2446-
for flag_idx, lib in zip(flag_idxs, libs):
2446+
for flag_idx, lib in zip(flag_idxs, libs, strict=True):
24472447
flag_list[flag_idx] = lib
24482448
return flag_list
24492449

pytensor/link/c/op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def make_c_thunk(
5959
e = FunctionGraph(node.inputs, node.outputs)
6060
e_no_recycling = [
6161
new_o
62-
for (new_o, old_o) in zip(e.outputs, node.outputs)
62+
for (new_o, old_o) in zip(e.outputs, node.outputs, strict=True)
6363
if old_o in no_recycling
6464
]
6565
cl = pytensor.link.c.basic.CLinker().accept(e, no_recycling=e_no_recycling)

pytensor/link/c/params_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def c_support_code(self, **kwargs):
709709
c_init_list = []
710710
c_cleanup_list = []
711711
c_extract_list = []
712-
for attribute_name, type_instance in zip(self.fields, self.types):
712+
for attribute_name, type_instance in zip(self.fields, self.types, strict=True):
713713
try:
714714
# c_support_code() may return a code string or a list of code strings.
715715
support_code = type_instance.c_support_code()

0 commit comments

Comments
 (0)