Skip to content

Commit 685c490

Browse files
committed
Make zips strict in pytensor/scan
1 parent cf9c5cb commit 685c490

File tree

4 files changed

+80
-47
lines changed

4 files changed

+80
-47
lines changed

pytensor/scan/basic.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -886,7 +886,9 @@ def wrap_into_list(x):
886886
if condition is not None:
887887
outputs.append(condition)
888888
fake_nonseqs = [x.type() for x in non_seqs]
889-
fake_outputs = clone_replace(outputs, replace=dict(zip(non_seqs, fake_nonseqs)))
889+
fake_outputs = clone_replace(
890+
outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
891+
)
890892
all_inputs = filter(
891893
lambda x: (
892894
isinstance(x, Variable)
@@ -1049,7 +1051,7 @@ def wrap_into_list(x):
10491051
if not isinstance(arg, SharedVariable | Constant)
10501052
]
10511053

1052-
inner_replacements.update(dict(zip(other_scan_args, other_inner_args)))
1054+
inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True)))
10531055

10541056
if strict:
10551057
non_seqs_set = set(non_sequences if non_sequences is not None else [])
@@ -1071,7 +1073,7 @@ def wrap_into_list(x):
10711073
]
10721074

10731075
inner_replacements.update(
1074-
dict(zip(other_shared_scan_args, other_shared_inner_args))
1076+
dict(zip(other_shared_scan_args, other_shared_inner_args, strict=True))
10751077
)
10761078

10771079
##

pytensor/scan/op.py

Lines changed: 43 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -171,7 +171,7 @@ def check_broadcast(v1, v2):
171171
)
172172
size = min(v1.type.ndim, v2.type.ndim)
173173
for n, (b1, b2) in enumerate(
174-
zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
174+
zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:], strict=False)
175175
):
176176
if b1 != b2:
177177
a1 = n + size - v1.type.ndim + 1
@@ -578,6 +578,7 @@ def get_oinp_iinp_iout_oout_mappings(self):
578578
inner_input_indices,
579579
inner_output_indices,
580580
outer_output_indices,
581+
strict=True,
581582
):
582583
if oout != -1:
583584
mappings["outer_inp_from_outer_out"][oout] = oinp
@@ -959,7 +960,7 @@ def make_node(self, *inputs):
959960
# them have the same dtype
960961
argoffset = 0
961962
for inner_seq, outer_seq in zip(
962-
self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs)
963+
self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs), strict=True
963964
):
964965
check_broadcast(outer_seq, inner_seq)
965966
new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))
@@ -978,6 +979,7 @@ def make_node(self, *inputs):
978979
self.info.mit_mot_in_slices,
979980
self.info.mit_mot_out_slices[: self.info.n_mit_mot],
980981
self.outer_mitmot(inputs),
982+
strict=True,
981983
)
982984
):
983985
outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos])
@@ -1032,6 +1034,7 @@ def make_node(self, *inputs):
10321034
self.info.mit_sot_in_slices,
10331035
self.outer_mitsot(inputs),
10341036
self.inner_mitsot_outs(self.inner_outputs),
1037+
strict=True,
10351038
)
10361039
):
10371040
outer_mitsot = copy_var_format(_outer_mitsot, as_var=inner_mitsots[ipos])
@@ -1084,6 +1087,7 @@ def make_node(self, *inputs):
10841087
self.inner_sitsot(self.inner_inputs),
10851088
self.outer_sitsot(inputs),
10861089
self.inner_sitsot_outs(self.inner_outputs),
1090+
strict=True,
10871091
)
10881092
):
10891093
outer_sitsot = copy_var_format(_outer_sitsot, as_var=inner_sitsot)
@@ -1131,6 +1135,7 @@ def make_node(self, *inputs):
11311135
self.inner_shared(self.inner_inputs),
11321136
self.inner_shared_outs(self.inner_outputs),
11331137
self.outer_shared(inputs),
1138+
strict=True,
11341139
)
11351140
):
11361141
outer_shared = copy_var_format(_outer_shared, as_var=inner_shared)
@@ -1189,7 +1194,9 @@ def make_node(self, *inputs):
11891194
# type of tensor as the output, it is always a scalar int.
11901195
new_inputs += [as_tensor_variable(ons) for ons in self.outer_nitsot(inputs)]
11911196
for inner_nonseq, _outer_nonseq in zip(
1192-
self.inner_non_seqs(self.inner_inputs), self.outer_non_seqs(inputs)
1197+
self.inner_non_seqs(self.inner_inputs),
1198+
self.outer_non_seqs(inputs),
1199+
strict=True,
11931200
):
11941201
outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq)
11951202
new_inputs.append(outer_nonseq)
@@ -1272,7 +1279,9 @@ def __eq__(self, other):
12721279
if len(self.inner_outputs) != len(other.inner_outputs):
12731280
return False
12741281

1275-
for self_in, other_in in zip(self.inner_inputs, other.inner_inputs):
1282+
for self_in, other_in in zip(
1283+
self.inner_inputs, other.inner_inputs, strict=True
1284+
):
12761285
if self_in.type != other_in.type:
12771286
return False
12781287

@@ -1407,7 +1416,7 @@ def prepare_fgraph(self, fgraph):
14071416
fgraph.attach_feature(
14081417
Supervisor(
14091418
inp
1410-
for spec, inp in zip(wrapped_inputs, fgraph.inputs)
1419+
for spec, inp in zip(wrapped_inputs, fgraph.inputs, strict=True)
14111420
if not (
14121421
getattr(spec, "mutable", None)
14131422
or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp]))
@@ -2087,7 +2096,9 @@ def perform(self, node, inputs, output_storage):
20872096
jout = j + offset_out
20882097
output_storage[j][0] = inner_output_storage[jout].storage[0]
20892098

2090-
pos = [(idx + 1) % store for idx, store in zip(pos, store_steps)]
2099+
pos = [
2100+
(idx + 1) % store for idx, store in zip(pos, store_steps, strict=True)
2101+
]
20912102
i = i + 1
20922103

20932104
# 6. Check if you need to re-order output buffers
@@ -2172,7 +2183,7 @@ def perform(self, node, inputs, output_storage):
21722183

21732184
def infer_shape(self, fgraph, node, input_shapes):
21742185
# input_shapes correspond to the shapes of node.inputs
2175-
for inp, inp_shp in zip(node.inputs, input_shapes):
2186+
for inp, inp_shp in zip(node.inputs, input_shapes, strict=True):
21762187
assert inp_shp is None or len(inp_shp) == inp.type.ndim
21772188

21782189
# Here we build 2 variables;
@@ -2241,7 +2252,9 @@ def infer_shape(self, fgraph, node, input_shapes):
22412252
# Non-sequences have a direct equivalent from self.inner_inputs in
22422253
# node.inputs
22432254
inner_non_sequences = self.inner_inputs[len(seqs_shape) + len(outs_shape) :]
2244-
for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
2255+
for in_ns, out_ns in zip(
2256+
inner_non_sequences, node.inputs[offset:], strict=True
2257+
):
22452258
out_equivalent[in_ns] = out_ns
22462259

22472260
if info.as_while:
@@ -2276,7 +2289,7 @@ def infer_shape(self, fgraph, node, input_shapes):
22762289
r = node.outputs[n_outs + x]
22772290
assert r.ndim == 1 + len(out_shape_x)
22782291
shp = [node.inputs[offset + info.n_shared_outs + x]]
2279-
for i, shp_i in zip(range(1, r.ndim), out_shape_x):
2292+
for i, shp_i in zip(range(1, r.ndim), out_shape_x, strict=True):
22802293
# Validate shp_i. v_shape_i is either None (if invalid),
22812294
# or a (variable, Boolean) tuple. The Boolean indicates
22822295
# whether variable is shp_i (if True), or an valid
@@ -2298,7 +2311,7 @@ def infer_shape(self, fgraph, node, input_shapes):
22982311
if info.as_while:
22992312
scan_outs_init = scan_outs
23002313
scan_outs = []
2301-
for o, x in zip(node.outputs, scan_outs_init):
2314+
for o, x in zip(node.outputs, scan_outs_init, strict=True):
23022315
if x is None:
23032316
scan_outs.append(None)
23042317
else:
@@ -2574,7 +2587,9 @@ def compute_all_gradients(known_grads):
25742587
dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
25752588
else:
25762589
disconnected_dC_dinps_t[dx] = False
2577-
for Xt, Xt_placeholder in zip(diff_outputs[info.n_mit_mot_outs :], Xts):
2590+
for Xt, Xt_placeholder in zip(
2591+
diff_outputs[info.n_mit_mot_outs :], Xts, strict=True
2592+
):
25782593
tmp = forced_replace(dC_dinps_t[dx], Xt, Xt_placeholder)
25792594
dC_dinps_t[dx] = tmp
25802595

@@ -2654,7 +2669,9 @@ def compute_all_gradients(known_grads):
26542669
n = n_steps.tag.test_value
26552670
else:
26562671
n = inputs[0].tag.test_value
2657-
for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs)):
2672+
for taps, x in zip(
2673+
info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True
2674+
):
26582675
mintap = np.min(taps)
26592676
if hasattr(x[::-1][:mintap], "test_value"):
26602677
assert x[::-1][:mintap].tag.test_value.shape[0] == n
@@ -2669,7 +2686,9 @@ def compute_all_gradients(known_grads):
26692686
assert x[::-1].tag.test_value.shape[0] == n
26702687
outer_inp_seqs += [
26712688
x[::-1][: np.min(taps)]
2672-
for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs))
2689+
for taps, x in zip(
2690+
info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True
2691+
)
26732692
]
26742693
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
26752694
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
@@ -3000,6 +3019,7 @@ def compute_all_gradients(known_grads):
30003019
zip(
30013020
outputs[offset : offset + info.n_seqs],
30023021
type_outs[offset : offset + info.n_seqs],
3022+
strict=True,
30033023
)
30043024
):
30053025
if t == "connected":
@@ -3029,7 +3049,7 @@ def compute_all_gradients(known_grads):
30293049
gradients.append(NullType(t)())
30303050

30313051
end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
3032-
for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
3052+
for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end], strict=True)):
30333053
if t == "connected":
30343054
# If the forward scan is in as_while mode, we need to pad
30353055
# the gradients, so that they match the size of the input
@@ -3064,7 +3084,7 @@ def compute_all_gradients(known_grads):
30643084
for idx in range(info.n_shared_outs):
30653085
disconnected = True
30663086
connected_flags = self.connection_pattern(node)[idx + start]
3067-
for dC_dout, connected in zip(dC_douts, connected_flags):
3087+
for dC_dout, connected in zip(dC_douts, connected_flags, strict=True):
30683088
if not isinstance(dC_dout.type, DisconnectedType) and connected:
30693089
disconnected = False
30703090
if disconnected:
@@ -3081,7 +3101,9 @@ def compute_all_gradients(known_grads):
30813101
begin = end
30823102

30833103
end = begin + n_sitsot_outs
3084-
for p, (x, t) in enumerate(zip(outputs[begin:end], type_outs[begin:end])):
3104+
for p, (x, t) in enumerate(
3105+
zip(outputs[begin:end], type_outs[begin:end], strict=True)
3106+
):
30853107
if t == "connected":
30863108
gradients.append(x[-1])
30873109
elif t == "disconnected":
@@ -3158,7 +3180,7 @@ def R_op(self, inputs, eval_points):
31583180
e = 1 + info.n_seqs
31593181
ie = info.n_seqs
31603182
clean_eval_points = []
3161-
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
3183+
for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
31623184
if evp is not None:
31633185
clean_eval_points.append(evp)
31643186
else:
@@ -3173,7 +3195,7 @@ def R_op(self, inputs, eval_points):
31733195
ib = ie
31743196
ie = ie + int(sum(len(x) for x in info.mit_mot_in_slices))
31753197
clean_eval_points = []
3176-
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
3198+
for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
31773199
if evp is not None:
31783200
clean_eval_points.append(evp)
31793201
else:
@@ -3188,7 +3210,7 @@ def R_op(self, inputs, eval_points):
31883210
ib = ie
31893211
ie = ie + int(sum(len(x) for x in info.mit_sot_in_slices))
31903212
clean_eval_points = []
3191-
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
3213+
for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
31923214
if evp is not None:
31933215
clean_eval_points.append(evp)
31943216
else:
@@ -3203,7 +3225,7 @@ def R_op(self, inputs, eval_points):
32033225
ib = ie
32043226
ie = ie + info.n_sit_sot
32053227
clean_eval_points = []
3206-
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
3228+
for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
32073229
if evp is not None:
32083230
clean_eval_points.append(evp)
32093231
else:
@@ -3227,7 +3249,7 @@ def R_op(self, inputs, eval_points):
32273249

32283250
# All other arguments
32293251
clean_eval_points = []
3230-
for inp, evp in zip(inputs[e:], eval_points[e:]):
3252+
for inp, evp in zip(inputs[e:], eval_points[e:], strict=True):
32313253
if evp is not None:
32323254
clean_eval_points.append(evp)
32333255
else:

0 commit comments

Comments (0)