Skip to content

Commit 9c0d35d

Browse files
committed
Make zips strict in pytensor/graph
1 parent fb1a9a6 commit 9c0d35d

File tree

4 files changed

+39
-29
lines changed

4 files changed

+39
-29
lines changed

pytensor/graph/basic.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def clone_with_new_inputs(
272272
# as the output type depends on the input values and not just their types
273273
output_type_depends_on_input_value = self.op._output_type_depends_on_input_value
274274

275-
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
275+
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs, strict=True)):
276276
# Check if the input type changed or if the Op has output types that depend on input values
277277
if (curr.type != new.type) or output_type_depends_on_input_value:
278278
# In strict mode, the cloned graph is assumed to be mathematically equivalent to the original one.
@@ -1295,7 +1295,7 @@ def clone_node_and_cache(
12951295
if new_node.op is not node.op:
12961296
clone_d.setdefault(node.op, new_node.op)
12971297

1298-
for old_o, new_o in zip(node.outputs, new_node.outputs):
1298+
for old_o, new_o in zip(node.outputs, new_node.outputs, strict=True):
12991299
clone_d.setdefault(old_o, new_o)
13001300

13011301
return new_node
@@ -1884,7 +1884,7 @@ def equal_computations(
18841884
if in_ys is None:
18851885
in_ys = []
18861886

1887-
for x, y in zip(xs, ys):
1887+
for x, y in zip(xs, ys, strict=True):
18881888
if not isinstance(x, Variable) and not isinstance(y, Variable):
18891889
return np.array_equal(x, y)
18901890
if not isinstance(x, Variable):
@@ -1907,13 +1907,13 @@ def equal_computations(
19071907
if len(in_xs) != len(in_ys):
19081908
return False
19091909

1910-
for _x, _y in zip(in_xs, in_ys):
1910+
for _x, _y in zip(in_xs, in_ys, strict=True):
19111911
if not (_y.type.in_same_class(_x.type)):
19121912
return False
19131913

1914-
common = set(zip(in_xs, in_ys))
1914+
common = set(zip(in_xs, in_ys, strict=True))
19151915
different: set[tuple[Variable, Variable]] = set()
1916-
for dx, dy in zip(xs, ys):
1916+
for dx, dy in zip(xs, ys, strict=True):
19171917
assert isinstance(dx, Variable)
19181918
# We checked above that both dx and dy have an owner or not
19191919
if dx.owner is None:
@@ -1949,7 +1949,7 @@ def compare_nodes(nd_x, nd_y, common, different):
19491949
return False
19501950
else:
19511951
all_in_common = True
1952-
for dx, dy in zip(nd_x.outputs, nd_y.outputs):
1952+
for dx, dy in zip(nd_x.outputs, nd_y.outputs, strict=True):
19531953
if (dx, dy) in different:
19541954
return False
19551955
if (dx, dy) not in common:
@@ -1959,7 +1959,7 @@ def compare_nodes(nd_x, nd_y, common, different):
19591959
return True
19601960

19611961
# Compare the individual inputs for equality
1962-
for dx, dy in zip(nd_x.inputs, nd_y.inputs):
1962+
for dx, dy in zip(nd_x.inputs, nd_y.inputs, strict=True):
19631963
if (dx, dy) not in common:
19641964
# Equality between the variables is unknown, compare
19651965
# their respective owners, if they have some
@@ -1994,7 +1994,7 @@ def compare_nodes(nd_x, nd_y, common, different):
19941994
# If the code reaches this statement then the inputs are pair-wise
19951995
# equivalent so the outputs of the current nodes are also
19961996
# pair-wise equivalents
1997-
for dx, dy in zip(nd_x.outputs, nd_y.outputs):
1997+
for dx, dy in zip(nd_x.outputs, nd_y.outputs, strict=True):
19981998
common.add((dx, dy))
19991999

20002000
return True

pytensor/graph/op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,14 +231,14 @@ def make_node(self, *inputs: Variable) -> Apply:
231231
)
232232
if not all(
233233
expected_type.is_super(var.type)
234-
for var, expected_type in zip(inputs, self.itypes)
234+
for var, expected_type in zip(inputs, self.itypes, strict=True)
235235
):
236236
raise TypeError(
237237
f"Invalid input types for Op {self}:\n"
238238
+ "\n".join(
239239
f"Input {i}/{len(inputs)}: Expected {inp}, got {out}"
240240
for i, (inp, out) in enumerate(
241-
zip(self.itypes, (inp.type for inp in inputs)),
241+
zip(self.itypes, (inp.type for inp in inputs), strict=True),
242242
start=1,
243243
)
244244
if inp != out

pytensor/graph/replace.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def clone_replace(
7878
items = list(_format_replace(replace).items())
7979

8080
tmp_replace = [(x, x.type()) for x, y in items]
81-
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
81+
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items, strict=True)]
8282
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
8383

8484
# TODO Explain why we call it twice ?!
@@ -295,11 +295,11 @@ def vectorize_graph(
295295
inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
296296
new_inputs = [replace.get(inp, inp) for inp in inputs]
297297

298-
vect_vars = dict(zip(inputs, new_inputs))
298+
vect_vars = dict(zip(inputs, new_inputs, strict=True))
299299
for node in io_toposort(inputs, seq_outputs):
300300
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
301301
vect_node = vectorize_node(node, *vect_inputs)
302-
for output, vect_output in zip(node.outputs, vect_node.outputs):
302+
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):
303303
if output in vect_vars:
304304
# This can happen when some outputs of a multi-output node are given a replacement,
305305
# while some of the remaining outputs are still needed in the graph.

pytensor/graph/rewriting/basic.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -399,14 +399,14 @@ def print_profile(cls, stream, prof, level=0):
399399
file=stream,
400400
)
401401
ll = []
402-
for rewrite, nb_n in zip(rewrites, nb_nodes):
402+
for rewrite, nb_n in zip(rewrites, nb_nodes, strict=True):
403403
if hasattr(rewrite, "__name__"):
404404
name = rewrite.__name__
405405
else:
406406
name = rewrite.name
407407
idx = rewrites.index(rewrite)
408408
ll.append((name, rewrite.__class__.__name__, idx, *nb_n))
409-
lll = sorted(zip(prof, ll), key=lambda a: a[0])
409+
lll = sorted(zip(prof, ll, strict=True), key=lambda a: a[0])
410410

411411
for t, rewrite in lll[::-1]:
412412
i = rewrite[2]
@@ -480,7 +480,8 @@ def merge_profile(prof1, prof2):
480480

481481
new_rewrite = SequentialGraphRewriter(*new_l)
482482
new_nb_nodes = [
483-
(p1[0] + p2[0], p1[1] + p2[1]) for p1, p2 in zip(prof1[8], prof2[8])
483+
(p1[0] + p2[0], p1[1] + p2[1])
484+
for p1, p2 in zip(prof1[8], prof2[8], strict=True)
484485
]
485486
new_nb_nodes.extend(prof1[8][len(new_nb_nodes) :])
486487
new_nb_nodes.extend(prof2[8][len(new_nb_nodes) :])
@@ -635,7 +636,7 @@ def process_node(self, fgraph, node):
635636

636637
inputs_match = all(
637638
node_in is cand_in
638-
for node_in, cand_in in zip(node.inputs, candidate.inputs)
639+
for node_in, cand_in in zip(node.inputs, candidate.inputs, strict=True)
639640
)
640641

641642
if inputs_match and node.op == candidate.op:
@@ -649,6 +650,7 @@ def process_node(self, fgraph, node):
649650
node.outputs,
650651
candidate.outputs,
651652
["merge"] * len(node.outputs),
653+
strict=True,
652654
)
653655
)
654656

@@ -721,7 +723,9 @@ def apply(self, fgraph):
721723
inputs_match = all(
722724
node_in is cand_in
723725
for node_in, cand_in in zip(
724-
var.owner.inputs, candidate_var.owner.inputs
726+
var.owner.inputs,
727+
candidate_var.owner.inputs,
728+
strict=True,
725729
)
726730
)
727731

@@ -1434,7 +1438,7 @@ def transform(self, fgraph, node):
14341438
repl = self.op2.make_node(*node.inputs)
14351439
if self.transfer_tags:
14361440
repl.tag = copy.copy(node.tag)
1437-
for output, new_output in zip(node.outputs, repl.outputs):
1441+
for output, new_output in zip(node.outputs, repl.outputs, strict=True):
14381442
new_output.tag = copy.copy(output.tag)
14391443
return repl.outputs
14401444

@@ -1616,7 +1620,7 @@ def transform(self, fgraph, node, get_nodes=True):
16161620
continue
16171621
ret = self.transform(fgraph, real_node, get_nodes=False)
16181622
if ret is not False and ret is not None:
1619-
return dict(zip(real_node.outputs, ret))
1623+
return dict(zip(real_node.outputs, ret, strict=True))
16201624

16211625
if node.op != self.op:
16221626
return False
@@ -1648,7 +1652,7 @@ def transform(self, fgraph, node, get_nodes=True):
16481652
len(node.outputs) == len(ret.owner.outputs)
16491653
and all(
16501654
o.type.is_super(new_o.type)
1651-
for o, new_o in zip(node.outputs, ret.owner.outputs)
1655+
for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
16521656
)
16531657
):
16541658
return False
@@ -1940,7 +1944,7 @@ def process_node(
19401944
)
19411945
# None in the replacement mean that this variable isn't used
19421946
# and we want to remove it
1943-
for r, rnew in zip(old_vars, replacements):
1947+
for r, rnew in zip(old_vars, replacements, strict=True):
19441948
if rnew is None and len(fgraph.clients[r]) > 0:
19451949
raise ValueError(
19461950
f"Node rewriter {node_rewriter} tried to remove a variable"
@@ -1950,7 +1954,7 @@ def process_node(
19501954
# the replacement
19511955
repl_pairs = [
19521956
(r, rnew)
1953-
for r, rnew in zip(old_vars, replacements)
1957+
for r, rnew in zip(old_vars, replacements, strict=True)
19541958
if rnew is not r and rnew is not None
19551959
]
19561960

@@ -2633,17 +2637,23 @@ def print_profile(cls, stream, prof, level=0):
26332637
print(blanc, "Global, final, and clean up rewriters", file=stream)
26342638
for i in range(len(loop_timing)):
26352639
print(blanc, f"Iter {int(i)}", file=stream)
2636-
for o, prof in zip(rewrite.global_rewriters, global_sub_profs[i]):
2640+
for o, prof in zip(
2641+
rewrite.global_rewriters, global_sub_profs[i], strict=True
2642+
):
26372643
try:
26382644
o.print_profile(stream, prof, level + 2)
26392645
except NotImplementedError:
26402646
print(blanc, "merge not implemented for ", o)
2641-
for o, prof in zip(rewrite.final_rewriters, final_sub_profs[i]):
2647+
for o, prof in zip(
2648+
rewrite.final_rewriters, final_sub_profs[i], strict=True
2649+
):
26422650
try:
26432651
o.print_profile(stream, prof, level + 2)
26442652
except NotImplementedError:
26452653
print(blanc, "merge not implemented for ", o)
2646-
for o, prof in zip(rewrite.cleanup_rewriters, cleanup_sub_profs[i]):
2654+
for o, prof in zip(
2655+
rewrite.cleanup_rewriters, cleanup_sub_profs[i], strict=True
2656+
):
26472657
try:
26482658
o.print_profile(stream, prof, level + 2)
26492659
except NotImplementedError:
@@ -2861,7 +2871,7 @@ def local_recursive_function(
28612871
outs, rewritten_vars = local_recursive_function(
28622872
rewrite_list, inp, rewritten_vars, depth + 1
28632873
)
2864-
for k, v in zip(inp.owner.outputs, outs):
2874+
for k, v in zip(inp.owner.outputs, outs, strict=True):
28652875
rewritten_vars[k] = v
28662876
nw_in = outs[inp.owner.outputs.index(inp)]
28672877

@@ -2879,7 +2889,7 @@ def local_recursive_function(
28792889
if ret is not False and ret is not None:
28802890
assert isinstance(ret, Sequence)
28812891
assert len(ret) == len(node.outputs), rewrite
2882-
for k, v in zip(node.outputs, ret):
2892+
for k, v in zip(node.outputs, ret, strict=True):
28832893
rewritten_vars[k] = v
28842894
results = ret
28852895
if ret[0].owner:

0 commit comments

Comments
 (0)