Commit 8310fe7

Make the remaining zips strict
1 parent cf43497 commit 8310fe7

File tree

3 files changed: +27 -21 lines changed

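The diffs below all make the same kind of change: passing strict=True to zip(), available since Python 3.10, so that a length mismatch between the zipped iterables raises ValueError instead of silently truncating to the shortest input. A minimal sketch of that behavior, using made-up lists rather than any pytensor objects:

    outputs = ["out0", "out1", "out2"]
    grads = ["g0", "g1"]  # one element short

    list(zip(outputs, grads))        # [('out0', 'g0'), ('out1', 'g1')] -- the mismatch goes unnoticed
    try:
        list(zip(outputs, grads, strict=True))
    except ValueError as err:
        print(err)                   # zip() argument 2 is shorter than argument 1

In other words, anywhere the paired sequences unexpectedly differ in length, the code now fails loudly at the zip call instead of continuing with a silently shortened result.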

pytensor/gradient.py

Lines changed: 14 additions & 12 deletions
@@ -213,7 +213,7 @@ def Rop(
 
     # Check that each element of wrt corresponds to an element
     # of eval_points with the same dimensionality.
-    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points)):
+    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)):
         try:
             if wrt_elem.type.ndim != eval_point.type.ndim:
                 raise ValueError(
@@ -262,7 +262,7 @@ def _traverse(node):
                     seen_nodes[inp.owner][inp.owner.outputs.index(inp)]
                 )
         same_type_eval_points = []
-        for x, y in zip(inputs, local_eval_points):
+        for x, y in zip(inputs, local_eval_points, strict=True):
             if y is not None:
                 if not isinstance(x, Variable):
                     x = pytensor.tensor.as_tensor_variable(x)
@@ -399,7 +399,7 @@ def Lop(
     _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
 
     assert len(_f) == len(grads)
-    known = dict(zip(_f, grads))
+    known = dict(zip(_f, grads, strict=True))
 
     ret = grad(
         cost=None,
@@ -778,7 +778,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
         for i in range(len(grads)):
             grads[i] += cost_grads[i]
 
-    pgrads = dict(zip(params, grads))
+    pgrads = dict(zip(params, grads, strict=True))
     # separate wrt from end grads:
     wrt_grads = [pgrads[k] for k in wrt]
     end_grads = [pgrads[k] for k in end]
@@ -1045,7 +1045,7 @@ def access_term_cache(node):
                 in [
                     input_to_output and output_to_cost
                     for input_to_output, output_to_cost in zip(
-                        input_to_outputs, outputs_connected
+                        input_to_outputs, outputs_connected, strict=True
                     )
                 ]
             )
@@ -1071,7 +1071,7 @@ def access_term_cache(node):
                 not in [
                     in_to_out and out_to_cost and not out_nan
                     for in_to_out, out_to_cost, out_nan in zip(
-                        in_to_outs, outputs_connected, ograd_is_nan
+                        in_to_outs, outputs_connected, ograd_is_nan, strict=True
                     )
                 ]
             )
@@ -1131,7 +1131,7 @@ def try_to_copy_if_needed(var):
             # DO NOT force integer variables to have integer dtype.
             # This is a violation of the op contract.
             new_output_grads = []
-            for o, og in zip(node.outputs, output_grads):
+            for o, og in zip(node.outputs, output_grads, strict=True):
                 o_dt = getattr(o.type, "dtype", None)
                 og_dt = getattr(og.type, "dtype", None)
                 if (
@@ -1145,7 +1145,7 @@ def try_to_copy_if_needed(var):
 
             # Make sure that, if new_output_grads[i] has a floating point
             # dtype, it is the same dtype as outputs[i]
-            for o, ng in zip(node.outputs, new_output_grads):
+            for o, ng in zip(node.outputs, new_output_grads, strict=True):
                 o_dt = getattr(o.type, "dtype", None)
                 ng_dt = getattr(ng.type, "dtype", None)
                 if (
@@ -1167,7 +1167,9 @@ def try_to_copy_if_needed(var):
             # by the user, not computed by Op.grad, and some gradients are
             # only computed and returned, but never passed as another
             # node's output grads.
-            for idx, packed in enumerate(zip(node.outputs, new_output_grads)):
+            for idx, packed in enumerate(
+                zip(node.outputs, new_output_grads, strict=True)
+            ):
                 orig_output, new_output_grad = packed
                 if not hasattr(orig_output, "shape"):
                     continue
@@ -1233,7 +1235,7 @@ def try_to_copy_if_needed(var):
                 not in [
                     in_to_out and out_to_cost and not out_int
                     for in_to_out, out_to_cost, out_int in zip(
-                        in_to_outs, outputs_connected, output_is_int
+                        in_to_outs, outputs_connected, output_is_int, strict=True
                     )
                 ]
             )
@@ -1314,7 +1316,7 @@ def try_to_copy_if_needed(var):
             # Check that op.connection_pattern matches the connectivity
             # logic driving the op.grad method
             for i, (ipt, ig, connected) in enumerate(
-                zip(inputs, input_grads, inputs_connected)
+                zip(inputs, input_grads, inputs_connected, strict=True)
             ):
                 actually_connected = not isinstance(ig.type, DisconnectedType)
 
@@ -1601,7 +1603,7 @@ def abs_rel_errors(self, g_pt):
         if len(g_pt) != len(self.gf):
             raise ValueError("argument has wrong number of elements", len(g_pt))
         errs = []
-        for i, (a, b) in enumerate(zip(g_pt, self.gf)):
+        for i, (a, b) in enumerate(zip(g_pt, self.gf, strict=True)):
             if a.shape != b.shape:
                 raise ValueError(
                     f"argument element {i} has wrong shapes {a.shape}, {b.shape}"

pytensor/ifelse.py

Lines changed: 10 additions & 6 deletions
@@ -170,7 +170,9 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any):
         output_vars = []
         new_inputs_true_branch = []
         new_inputs_false_branch = []
-        for input_t, input_f in zip(inputs_true_branch, inputs_false_branch):
+        for input_t, input_f in zip(
+            inputs_true_branch, inputs_false_branch, strict=True
+        ):
             if not isinstance(input_t, Variable):
                 input_t = as_symbolic(input_t)
             if not isinstance(input_f, Variable):
@@ -207,7 +209,9 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any):
             # allowed to have distinct shapes from either branch
             new_shape = tuple(
                 s_t if s_t == s_f else None
-                for s_t, s_f in zip(input_t.type.shape, input_f.type.shape)
+                for s_t, s_f in zip(
+                    input_t.type.shape, input_f.type.shape, strict=True
+                )
             )
             # TODO FIXME: The presence of this keyword is a strong
             # assumption. Find something that's guaranteed by the/a
@@ -301,7 +305,7 @@ def thunk():
                 if len(ls) > 0:
                     return ls
                 else:
-                    for out, t in zip(outputs, input_true_branch):
+                    for out, t in zip(outputs, input_true_branch, strict=True):
                         compute_map[out][0] = 1
                         val = storage_map[t][0]
                         if self.as_view:
@@ -321,7 +325,7 @@ def thunk():
                 if len(ls) > 0:
                     return ls
                 else:
-                    for out, f in zip(outputs, inputs_false_branch):
+                    for out, f in zip(outputs, inputs_false_branch, strict=True):
                         compute_map[out][0] = 1
                         # can't view both outputs unless destroyhandler
                         # improves
@@ -637,7 +641,7 @@ def apply(self, fgraph):
                     old_outs += [proposal.outputs]
                 else:
                     old_outs += proposal.outputs
-                pairs = list(zip(old_outs, new_outs))
+                pairs = list(zip(old_outs, new_outs, strict=True))
                 fgraph.replace_all_validate(pairs, reason="cond_merge")
 
 
@@ -736,7 +740,7 @@ def cond_merge_random_op(fgraph, main_node):
             old_outs += [proposal.outputs]
         else:
             old_outs += proposal.outputs
-        pairs = list(zip(old_outs, new_outs))
+        pairs = list(zip(old_outs, new_outs, strict=True))
         main_outs = clone_replace(main_node.outputs, replace=pairs)
         return main_outs
 

pytensor/printing.py

Lines changed: 3 additions & 3 deletions
@@ -311,7 +311,7 @@ def debugprint(
     )
 
     for var, profile, storage_map, topo_order in zip(
-        outputs_to_print, profile_list, storage_maps, topo_orders
+        outputs_to_print, profile_list, storage_maps, topo_orders, strict=True
     ):
         if hasattr(var.owner, "op"):
             if (
@@ -936,7 +936,7 @@ def pp_process(input, new_precedence):
                 str(i): x
                 for i, x in enumerate(
                     pp_process(input, precedence)
-                    for input, precedence in zip(node.inputs, precedences)
+                    for input, precedence in zip(node.inputs, precedences, strict=True)
                 )
             }
             r = pattern % d
@@ -1449,7 +1449,7 @@ def apply_name(node):
     if isinstance(fct, Function):
         # TODO: Get rid of all this `expanded_inputs` nonsense and use
         # `fgraph.update_mapping`
-        function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs)
+        function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs, strict=True)
         for i, fg_ii in reversed(list(function_inputs)):
             if i.update is not None:
                 k = outputs.pop()
