Skip to content

Commit ec84ee6

Browse files
committed
Make zips strict in pytensor/graph
1 parent 685c490 commit ec84ee6

File tree

4 files changed

+38
-29
lines changed

4 files changed

+38
-29
lines changed

pytensor/graph/basic.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def clone_with_new_inputs(
272272
# as the output type depends on the input values and not just their types
273273
output_type_depends_on_input_value = self.op._output_type_depends_on_input_value
274274

275-
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
275+
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs, strict=True)):
276276
# Check if the input type changed or if the Op has output types that depend on input values
277277
if (curr.type != new.type) or output_type_depends_on_input_value:
278278
# In strict mode, the cloned graph is assumed to be mathematically equivalent to the original one.
@@ -1295,7 +1295,7 @@ def clone_node_and_cache(
12951295
if new_node.op is not node.op:
12961296
clone_d.setdefault(node.op, new_node.op)
12971297

1298-
for old_o, new_o in zip(node.outputs, new_node.outputs):
1298+
for old_o, new_o in zip(node.outputs, new_node.outputs, strict=True):
12991299
clone_d.setdefault(old_o, new_o)
13001300

13011301
return new_node
@@ -1885,7 +1885,7 @@ def equal_computations(
18851885
if in_ys is None:
18861886
in_ys = []
18871887

1888-
for x, y in zip(xs, ys):
1888+
for x, y in zip(xs, ys, strict=True):
18891889
if not isinstance(x, Variable) and not isinstance(y, Variable):
18901890
return np.array_equal(x, y)
18911891
if not isinstance(x, Variable):
@@ -1908,13 +1908,13 @@ def equal_computations(
19081908
if len(in_xs) != len(in_ys):
19091909
return False
19101910

1911-
for _x, _y in zip(in_xs, in_ys):
1911+
for _x, _y in zip(in_xs, in_ys, strict=True):
19121912
if not (_y.type.in_same_class(_x.type)):
19131913
return False
19141914

1915-
common = set(zip(in_xs, in_ys))
1915+
common = set(zip(in_xs, in_ys, strict=True))
19161916
different: set[tuple[Variable, Variable]] = set()
1917-
for dx, dy in zip(xs, ys):
1917+
for dx, dy in zip(xs, ys, strict=True):
19181918
assert isinstance(dx, Variable)
19191919
# We checked above that both dx and dy have an owner or not
19201920
if dx.owner is None:
@@ -1950,7 +1950,7 @@ def compare_nodes(nd_x, nd_y, common, different):
19501950
return False
19511951
else:
19521952
all_in_common = True
1953-
for dx, dy in zip(nd_x.outputs, nd_y.outputs):
1953+
for dx, dy in zip(nd_x.outputs, nd_y.outputs, strict=True):
19541954
if (dx, dy) in different:
19551955
return False
19561956
if (dx, dy) not in common:
@@ -1960,7 +1960,7 @@ def compare_nodes(nd_x, nd_y, common, different):
19601960
return True
19611961

19621962
# Compare the individual inputs for equality
1963-
for dx, dy in zip(nd_x.inputs, nd_y.inputs):
1963+
for dx, dy in zip(nd_x.inputs, nd_y.inputs, strict=True):
19641964
if (dx, dy) not in common:
19651965
# Equality between the variables is unknown, compare
19661966
# their respective owners, if they have some
@@ -1995,7 +1995,7 @@ def compare_nodes(nd_x, nd_y, common, different):
19951995
# If the code reaches this statement then the inputs are pair-wise
19961996
# equivalent so the outputs of the current nodes are also
19971997
# pair-wise equivalents
1998-
for dx, dy in zip(nd_x.outputs, nd_y.outputs):
1998+
for dx, dy in zip(nd_x.outputs, nd_y.outputs, strict=True):
19991999
common.add((dx, dy))
20002000

20012001
return True

pytensor/graph/op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,14 +231,14 @@ def make_node(self, *inputs: Variable) -> Apply:
231231
)
232232
if not all(
233233
expected_type.is_super(var.type)
234-
for var, expected_type in zip(inputs, self.itypes)
234+
for var, expected_type in zip(inputs, self.itypes, strict=True)
235235
):
236236
raise TypeError(
237237
f"Invalid input types for Op {self}:\n"
238238
+ "\n".join(
239239
f"Input {i}/{len(inputs)}: Expected {inp}, got {out}"
240240
for i, (inp, out) in enumerate(
241-
zip(self.itypes, (inp.type for inp in inputs)),
241+
zip(self.itypes, (inp.type for inp in inputs), strict=True),
242242
start=1,
243243
)
244244
if inp != out

pytensor/graph/replace.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def clone_replace(
7878
items = list(_format_replace(replace).items())
7979

8080
tmp_replace = [(x, x.type()) for x, y in items]
81-
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
81+
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items, strict=True)]
8282
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
8383

8484
# TODO Explain why we call it twice ?!
@@ -295,11 +295,11 @@ def vectorize_graph(
295295
inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
296296
new_inputs = [replace.get(inp, inp) for inp in inputs]
297297

298-
vect_vars = dict(zip(inputs, new_inputs))
298+
vect_vars = dict(zip(inputs, new_inputs, strict=True))
299299
for node in io_toposort(inputs, seq_outputs):
300300
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
301301
vect_node = vectorize_node(node, *vect_inputs)
302-
for output, vect_output in zip(node.outputs, vect_node.outputs):
302+
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):
303303
if output in vect_vars:
304304
# This can happen when some outputs of a multi-output node are given a replacement,
305305
# while some of the remaining outputs are still needed in the graph.

pytensor/graph/rewriting/basic.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -399,14 +399,14 @@ def print_profile(cls, stream, prof, level=0):
399399
file=stream,
400400
)
401401
ll = []
402-
for rewrite, nb_n in zip(rewrites, nb_nodes):
402+
for rewrite, nb_n in zip(rewrites, nb_nodes, strict=True):
403403
if hasattr(rewrite, "__name__"):
404404
name = rewrite.__name__
405405
else:
406406
name = rewrite.name
407407
idx = rewrites.index(rewrite)
408408
ll.append((name, rewrite.__class__.__name__, idx, *nb_n))
409-
lll = sorted(zip(prof, ll), key=lambda a: a[0])
409+
lll = sorted(zip(prof, ll, strict=True), key=lambda a: a[0])
410410

411411
for t, rewrite in lll[::-1]:
412412
i = rewrite[2]
@@ -480,7 +480,7 @@ def merge_profile(prof1, prof2):
480480

481481
new_rewrite = SequentialGraphRewriter(*new_l)
482482
new_nb_nodes = []
483-
for p1, p2 in zip(prof1[8], prof2[8]):
483+
for p1, p2 in zip(prof1[8], prof2[8], strict=True):
484484
new_nb_nodes.append((p1[0] + p2[0], p1[1] + p2[1]))
485485
new_nb_nodes.extend(prof1[8][len(new_nb_nodes) :])
486486
new_nb_nodes.extend(prof2[8][len(new_nb_nodes) :])
@@ -635,7 +635,7 @@ def process_node(self, fgraph, node):
635635

636636
inputs_match = all(
637637
node_in is cand_in
638-
for node_in, cand_in in zip(node.inputs, candidate.inputs)
638+
for node_in, cand_in in zip(node.inputs, candidate.inputs, strict=True)
639639
)
640640

641641
if inputs_match and node.op == candidate.op:
@@ -649,6 +649,7 @@ def process_node(self, fgraph, node):
649649
node.outputs,
650650
candidate.outputs,
651651
["merge"] * len(node.outputs),
652+
strict=True,
652653
)
653654
)
654655

@@ -721,7 +722,9 @@ def apply(self, fgraph):
721722
inputs_match = all(
722723
node_in is cand_in
723724
for node_in, cand_in in zip(
724-
var.owner.inputs, candidate_var.owner.inputs
725+
var.owner.inputs,
726+
candidate_var.owner.inputs,
727+
strict=True,
725728
)
726729
)
727730

@@ -1440,7 +1443,7 @@ def transform(self, fgraph, node):
14401443
repl = self.op2.make_node(*node.inputs)
14411444
if self.transfer_tags:
14421445
repl.tag = copy.copy(node.tag)
1443-
for output, new_output in zip(node.outputs, repl.outputs):
1446+
for output, new_output in zip(node.outputs, repl.outputs, strict=True):
14441447
new_output.tag = copy.copy(output.tag)
14451448
return repl.outputs
14461449

@@ -1622,7 +1625,7 @@ def transform(self, fgraph, node, get_nodes=True):
16221625
continue
16231626
ret = self.transform(fgraph, real_node, get_nodes=False)
16241627
if ret is not False and ret is not None:
1625-
return dict(zip(real_node.outputs, ret))
1628+
return dict(zip(real_node.outputs, ret, strict=True))
16261629

16271630
if node.op != self.op:
16281631
return False
@@ -1654,7 +1657,7 @@ def transform(self, fgraph, node, get_nodes=True):
16541657
len(node.outputs) == len(ret.owner.outputs)
16551658
and all(
16561659
o.type.is_super(new_o.type)
1657-
for o, new_o in zip(node.outputs, ret.owner.outputs)
1660+
for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
16581661
)
16591662
):
16601663
return False
@@ -1946,7 +1949,7 @@ def process_node(
19461949
)
19471950
# None in the replacement mean that this variable isn't used
19481951
# and we want to remove it
1949-
for r, rnew in zip(old_vars, replacements):
1952+
for r, rnew in zip(old_vars, replacements, strict=True):
19501953
if rnew is None and len(fgraph.clients[r]) > 0:
19511954
raise ValueError(
19521955
f"Node rewriter {node_rewriter} tried to remove a variable"
@@ -1956,7 +1959,7 @@ def process_node(
19561959
# the replacement
19571960
repl_pairs = [
19581961
(r, rnew)
1959-
for r, rnew in zip(old_vars, replacements)
1962+
for r, rnew in zip(old_vars, replacements, strict=True)
19601963
if rnew is not r and rnew is not None
19611964
]
19621965

@@ -2651,17 +2654,23 @@ def print_profile(cls, stream, prof, level=0):
26512654
print(blanc, "Global, final, and clean up rewriters", file=stream)
26522655
for i in range(len(loop_timing)):
26532656
print(blanc, f"Iter {int(i)}", file=stream)
2654-
for o, prof in zip(rewrite.global_rewriters, global_sub_profs[i]):
2657+
for o, prof in zip(
2658+
rewrite.global_rewriters, global_sub_profs[i], strict=True
2659+
):
26552660
try:
26562661
o.print_profile(stream, prof, level + 2)
26572662
except NotImplementedError:
26582663
print(blanc, "merge not implemented for ", o)
2659-
for o, prof in zip(rewrite.final_rewriters, final_sub_profs[i]):
2664+
for o, prof in zip(
2665+
rewrite.final_rewriters, final_sub_profs[i], strict=True
2666+
):
26602667
try:
26612668
o.print_profile(stream, prof, level + 2)
26622669
except NotImplementedError:
26632670
print(blanc, "merge not implemented for ", o)
2664-
for o, prof in zip(rewrite.cleanup_rewriters, cleanup_sub_profs[i]):
2671+
for o, prof in zip(
2672+
rewrite.cleanup_rewriters, cleanup_sub_profs[i], strict=True
2673+
):
26652674
try:
26662675
o.print_profile(stream, prof, level + 2)
26672676
except NotImplementedError:
@@ -2879,7 +2888,7 @@ def local_recursive_function(
28792888
outs, rewritten_vars = local_recursive_function(
28802889
rewrite_list, inp, rewritten_vars, depth + 1
28812890
)
2882-
for k, v in zip(inp.owner.outputs, outs):
2891+
for k, v in zip(inp.owner.outputs, outs, strict=True):
28832892
rewritten_vars[k] = v
28842893
nw_in = outs[inp.owner.outputs.index(inp)]
28852894

@@ -2897,7 +2906,7 @@ def local_recursive_function(
28972906
if ret is not False and ret is not None:
28982907
assert isinstance(ret, Sequence)
28992908
assert len(ret) == len(node.outputs), rewrite
2900-
for k, v in zip(node.outputs, ret):
2909+
for k, v in zip(node.outputs, ret, strict=True):
29012910
rewritten_vars[k] = v
29022911
results = ret
29032912
if ret[0].owner:

0 commit comments

Comments
 (0)