Commit c0779ec

Make zips strict in pytensor/tensor/rewriting
1 parent 2b08fcf commit c0779ec

7 files changed: 38 additions, 26 deletions

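All seven files get the same treatment: every zip(...) in the rewrite passes is given strict=True (available since Python 3.10), so a silently truncated iteration becomes a hard error. A minimal sketch of the behavior the commit relies on, using toy lists rather than PyTensor variables:

# zip() stops at the shortest iterable by default;
# strict=True turns a length mismatch into a ValueError instead.
pairs = list(zip([1, 2, 3], ["a", "b", "c"], strict=True))
print(pairs)  # [(1, 'a'), (2, 'b'), (3, 'c')]

try:
    list(zip([1, 2, 3], ["a", "b"], strict=True))
except ValueError as exc:
    print(exc)  # zip() argument 2 is shorter than argument 1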

pytensor/tensor/rewriting/basic.py

Lines changed: 3 additions & 3 deletions
@@ -97,11 +97,11 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool:
     if len(bx) < len(by):
         return True
     bx = bx[-len(by) :]
-    return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by))
+    return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by, strict=True))


 def merge_broadcastables(broadcastables):
-    return [all(bcast) for bcast in zip(*broadcastables)]
+    return [all(bcast) for bcast in zip(*broadcastables, strict=True)]


 def alloc_like(

@@ -1203,7 +1203,7 @@ def local_merge_alloc(fgraph, node):
     # broadcasted dimensions to its inputs[0]. Eg:
     # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
     i = 0
-    for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev):
+    for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev, strict=True):
         if dim_inner != dim_outer:
             if isinstance(dim_inner, Constant) and dim_inner.data == 1:
                 pass
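For instance, merge_broadcastables (shown above) previously truncated to the shortest broadcastable pattern when the inputs disagreed on rank; with the strict zip a rank mismatch now raises. A small standalone sketch, with plain tuples standing in for broadcastable patterns:

def merge_broadcastables(broadcastables):
    # A merged dimension is broadcastable only if it is broadcastable in every input.
    return [all(bcast) for bcast in zip(*broadcastables, strict=True)]

print(merge_broadcastables([(True, False), (True, True)]))  # [True, False]

# Patterns of different lengths used to be silently truncated; now they raise.
try:
    merge_broadcastables([(True, False), (True,)])
except ValueError as exc:
    print(exc)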

pytensor/tensor/rewriting/blas.py

Lines changed: 1 addition & 1 deletion
@@ -502,7 +502,7 @@ def on_import(new_node):
                     ].tag.values_eq_approx = values_eq_approx_remove_inf_nan
                     try:
                         fgraph.replace_all_validate_remove(
-                            list(zip(node.outputs, new_outputs)),
+                            list(zip(node.outputs, new_outputs, strict=True)),
                             [old_dot22],
                             reason="GemmOptimizer",
                             # For now we disable the warning as we know case

pytensor/tensor/rewriting/blockwise.py

Lines changed: 6 additions & 3 deletions
@@ -109,7 +109,7 @@ def local_blockwise_alloc(fgraph, node):
     new_inputs = []
     batch_shapes = []
     can_push_any_alloc = False
-    for inp, inp_sig in zip(node.inputs, op.inputs_sig):
+    for inp, inp_sig in zip(node.inputs, op.inputs_sig, strict=True):
         if inp.owner and isinstance(inp.owner.op, Alloc):
             # Push batch dims from Alloc
             value, *shape = inp.owner.inputs

@@ -130,6 +130,7 @@ def local_blockwise_alloc(fgraph, node):
                     for broadcastable, dim in zip(
                         squeezed_value.type.broadcastable[:squeezed_value_batch_ndim],
                         tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
+                        strict=True,
                     )
                 ]
                 squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)

@@ -143,7 +144,7 @@ def local_blockwise_alloc(fgraph, node):
                 tuple(
                     1 if broadcastable else dim
                     for broadcastable, dim in zip(
-                        inp.type.broadcastable, shape[:batch_ndim]
+                        inp.type.broadcastable, shape[:batch_ndim], strict=True
                     )
                 )
             )

@@ -166,7 +167,9 @@ def local_blockwise_alloc(fgraph, node):
         # We pick the most parsimonious batch dim from the pushed Alloc
         missing_ndim = old_out_type.ndim - new_out_type.ndim
         batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
-        for i, batch_dims in enumerate(zip(*batch_shapes)):  # Transpose shape tuples
+        for i, batch_dims in enumerate(
+            zip(*batch_shapes, strict=True)
+        ):  # Transpose shape tuples
             for batch_dim in batch_dims:
                 if batch_dim == 1:
                     continue
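The enumerate(zip(*batch_shapes, strict=True)) change above transposes the per-input batch shapes into per-dimension groups, and the strict flag guards against inputs that report a different number of batch dimensions. A toy illustration with plain tuples standing in for the symbolic shapes:

# One batch-shape tuple per pushed Alloc input (toy values, not symbolic shapes).
batch_shapes = [(5, 1, 3), (1, 7, 3)]

# zip(*...) transposes: one tuple per batch dimension, across all inputs.
for i, batch_dims in enumerate(zip(*batch_shapes, strict=True)):
    print(i, batch_dims)
# 0 (5, 1)
# 1 (1, 7)
# 2 (3, 3)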

pytensor/tensor/rewriting/elemwise.py

Lines changed: 8 additions & 6 deletions
@@ -300,7 +300,7 @@ def apply(self, fgraph):
                         )
                         new_node = new_outputs[0].owner

-                        for r, new_r in zip(node.outputs, new_outputs):
+                        for r, new_r in zip(node.outputs, new_outputs, strict=True):
                             prof["nb_call_replace"] += 1
                             fgraph.replace(
                                 r, new_r, reason="inplace_elemwise_optimizer"

@@ -1037,12 +1037,12 @@ def update_fuseable_mappings_after_fg_replace(
                 )
                 if not isinstance(composite_outputs, list):
                     composite_outputs = [composite_outputs]
-                for old_out, composite_out in zip(outputs, composite_outputs):
+                for old_out, composite_out in zip(outputs, composite_outputs, strict=True):
                     if old_out.name:
                         composite_out.name = old_out.name

                 fgraph.replace_all_validate(
-                    list(zip(outputs, composite_outputs)),
+                    list(zip(outputs, composite_outputs, strict=True)),
                     reason=self.__class__.__name__,
                 )
                 nb_replacement += 1

@@ -1118,7 +1118,7 @@ def local_useless_composite_outputs(fgraph, node):
         used_inputs = [node.inputs[i] for i in used_inputs_idxs]
         c = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
         e = Elemwise(scalar_op=c)(*used_inputs, return_list=True)
-        return dict(zip([node.outputs[i] for i in used_outputs_idxs], e))
+        return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True))


 @node_rewriter([CAReduce])

@@ -1218,7 +1218,9 @@ def local_inline_composite_constants(fgraph, node):
     new_outer_inputs = []
     new_inner_inputs = []
     inner_replacements = {}
-    for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs):
+    for outer_inp, inner_inp in zip(
+        node.inputs, composite_op.fgraph.inputs, strict=True
+    ):
         # Complex variables don't have a `c_literal` that can be inlined
         if "complex" not in outer_inp.type.dtype:
             unique_value = get_unique_constant_value(outer_inp)

@@ -1355,7 +1357,7 @@ def local_useless_2f1grad_loop(fgraph, node):

     replacements = {converges: new_converges}
     i = 0
-    for grad_var, is_used in zip(grad_vars, grad_var_is_used):
+    for grad_var, is_used in zip(grad_vars, grad_var_is_used, strict=True):
         if not is_used:
             continue
         replacements[grad_var] = new_outs[i]
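Several of these hunks build replacement mappings, e.g. list(zip(outputs, composite_outputs, strict=True)) passed to fgraph.replace_all_validate; the strict flag asserts that the fused Composite yields exactly one new output per old output. A hedged sketch of that pairing, with placeholder strings instead of PyTensor variables:

old_outputs = ["out0", "out1"]          # stand-ins for the node's outputs
fused_outputs = ["fused0", "fused1"]    # stand-ins for the Composite's outputs

# Each (old, new) pair tells the FunctionGraph what to replace with what.
replacements = list(zip(old_outputs, fused_outputs, strict=True))
print(replacements)  # [('out0', 'fused0'), ('out1', 'fused1')]

# If the fusion dropped an output, the mismatch now raises rather than
# silently replacing only a prefix of the node's outputs.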

pytensor/tensor/rewriting/math.py

Lines changed: 6 additions & 2 deletions
@@ -1142,7 +1142,9 @@ def transform(self, fgraph, node):
         num, denum = self.simplify(list(orig_num), list(orig_denum), out.type)

         def same(x, y):
-            return len(x) == len(y) and all(np.all(xe == ye) for xe, ye in zip(x, y))
+            return len(x) == len(y) and all(
+                np.all(xe == ye) for xe, ye in zip(x, y, strict=True)
+            )

         if (
             same(orig_num, num)

@@ -2445,7 +2447,9 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, out_type, minscore=0):
             [(n + num, d + denum, out_type) for (n, d) in neg_pairs],
         )
     )
-    for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs):
+    for (n, d), (nn, dd) in zip(
+        pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs, strict=True
+    ):
         # We calculate how many operations we are saving with the new
         # num and denum
         score += len(n) + div_cost * len(d) - len(nn) - div_cost * len(dd)
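The same(x, y) helper above already checks len(x) == len(y) before zipping, so the strict flag there can never actually raise and serves mainly as documentation. A standalone sketch of the helper with toy NumPy inputs:

import numpy as np

def same(x, y):
    # Two factor lists are "the same" if they have equal length and every
    # corresponding pair of elements compares equal.
    return len(x) == len(y) and all(
        np.all(xe == ye) for xe, ye in zip(x, y, strict=True)
    )

print(same([np.array([1, 2])], [np.array([1, 2])]))  # True
print(same([np.array([1, 2])], [np.array([1, 3])]))  # False
print(same([np.array([1, 2])], []))                  # False (length check short-circuits)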

pytensor/tensor/rewriting/shape.py

Lines changed: 10 additions & 8 deletions
@@ -186,7 +186,7 @@ def get_shape(self, var, idx):

         # Only change the variables and dimensions that would introduce
         # extra computation
-        for new_shps, out in zip(o_shapes, node.outputs):
+        for new_shps, out in zip(o_shapes, node.outputs, strict=True):
             if not hasattr(out.type, "ndim"):
                 continue

@@ -578,7 +578,7 @@ def on_import(self, fgraph, node, reason):
                 new_shape += sh[len(new_shape) :]
                 o_shapes[sh_idx] = tuple(new_shape)

-        for r, s in zip(node.outputs, o_shapes):
+        for r, s in zip(node.outputs, o_shapes, strict=True):
             self.set_shape(r, s)

     def on_change_input(self, fgraph, node, i, r, new_r, reason):

@@ -709,7 +709,7 @@ def same_shape(
         sx = canon_shapes[: len(sx)]
         sy = canon_shapes[len(sx) :]

-    for dx, dy in zip(sx, sy):
+    for dx, dy in zip(sx, sy, strict=True):
         if not equal_computations([dx], [dy]):
             return False

@@ -778,7 +778,7 @@ def f(fgraph, node):
         # rewrite.
         if rval.type.ndim == node.outputs[0].type.ndim and all(
             s1 == s1
-            for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
+            for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape, strict=True)
             if s1 == 1 or s2 == 1
         ):
             return [rval]

@@ -817,7 +817,7 @@ def local_useless_reshape(fgraph, node):
         and output.type.ndim == 1
         and all(
             s1 == s2
-            for s1, s2 in zip(inp.type.shape, output.type.shape)
+            for s1, s2 in zip(inp.type.shape, output.type.shape, strict=True)
             if s1 == 1 or s2 == 1
         )
     ):

@@ -1068,7 +1068,9 @@ def local_specify_shape_lift(fgraph, node):

         nonbcast_dims = {
             i
-            for i, (dim, bcast) in enumerate(zip(shape, out_broadcastable))
+            for i, (dim, bcast) in enumerate(
+                zip(shape, out_broadcastable, strict=True)
+            )
             if (not bcast and not NoneConst.equals(dim))
         }
         new_elem_inps = elem_inps.copy()

@@ -1170,7 +1172,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
     new_order = node.inputs[0].owner.op.new_order
     inp = node.inputs[0].owner.inputs[0]
     new_order_of_nonbroadcast = []
-    for i, s in zip(new_order, node.inputs[0].type.shape):
+    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
         if s != 1:
             new_order_of_nonbroadcast.append(i)
     no_change_in_order = all(

@@ -1197,7 +1199,7 @@ def local_useless_unbroadcast(fgraph, node):
         x = node.inputs[0]
         if x.type.ndim == node.outputs[0].type.ndim and all(
             s1 == s2
-            for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape)
+            for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape, strict=True)
             if s1 == 1 or s2 == 1
         ):
             # No broadcastable flag was modified
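The shape.py hunks follow one pattern: the static shape tuples of input and output are compared element-wise, and strict=True asserts the two variables really have the same rank. A toy version of the comparison used in local_useless_unbroadcast, with plain tuples (None marking an unknown dimension) in place of the symbolic types:

x_shape = (1, 3, None)    # static shape of the input (toy values)
out_shape = (5, 3, None)  # static shape of the output

# Only dimensions where either side is statically 1 are compared;
# equal ranks are now enforced by the strict zip.
unchanged = all(
    s1 == s2
    for s1, s2 in zip(x_shape, out_shape, strict=True)
    if s1 == 1 or s2 == 1
)
print(unchanged)  # False: dimension 0 went from broadcastable (1) to 5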

pytensor/tensor/rewriting/subtensor.py

Lines changed: 4 additions & 3 deletions
@@ -647,7 +647,7 @@ def local_subtensor_of_alloc(fgraph, node):
     # Slices to take from val
     val_slices = []

-    for i, (sl, dim) in enumerate(zip(slices, dims)):
+    for i, (sl, dim) in enumerate(zip(slices, dims, strict=True)):
         # If val was not copied over that dim,
         # we need to take the appropriate subtensor on it.
         if i >= n_added_dims:

@@ -1771,7 +1771,7 @@ def local_join_subtensors(fgraph, node):
         if all(
             idxs_nonaxis_subtensor1 == idxs_nonaxis_subtensor2
             for i, (idxs_nonaxis_subtensor1, idxs_nonaxis_subtensor2) in enumerate(
-                zip(idxs_subtensor1, idxs_subtensor2)
+                zip(idxs_subtensor1, idxs_subtensor2, strict=True)
             )
             if i != axis
         ):

@@ -1913,7 +1913,7 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):

     x_batch_bcast = x.type.broadcastable[:batch_ndim]
     y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast)):
+    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
         # Need to broadcast batch x dims
         batch_shape = tuple(
             x_dim if (not xb or yb) else y_dim

@@ -1922,6 +1922,7 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
                 tuple(x.shape)[:batch_ndim],
                 y_batch_bcast,
                 tuple(y.shape)[:batch_ndim],
+                strict=True,
             )
         )
         core_shape = tuple(x.shape)[batch_ndim:]
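In the last two hunks above, the batch broadcastable patterns of x and y are compared pairwise to decide whether x's batch dimensions must be broadcast before the inc_subtensor; the strict zip assumes both patterns cover the same batch_ndim leading dimensions. A toy check with plain booleans in place of the type patterns:

x_batch_bcast = (True, False)   # x is broadcastable along the first batch dim
y_batch_bcast = (False, False)  # y is not

# x needs broadcasting wherever it is broadcastable but y is not.
needs_broadcast = any(
    xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)
)
print(needs_broadcast)  # True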
