Commit fb1a9a6

Make zips strict in pytensor/scan
1 parent f3845d1 commit fb1a9a6
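
The change itself is mechanical: the zips in pytensor/scan gain an explicit strict argument. With strict=True (PEP 618, Python 3.10+), zip raises ValueError when its iterables differ in length instead of silently truncating to the shortest one, so latent length mismatches surface as errors. A minimal sketch of the behavior, with illustrative data not taken from the commit:

    # zip's strict flag (PEP 618, Python >= 3.10); names are illustrative only.
    seqs = ["a", "b", "c"]
    taps = [0, 1]  # one element short

    # Default zip silently truncates to the shortest iterable:
    print(dict(zip(seqs, taps)))  # {'a': 0, 'b': 1} -- mismatch goes unnoticed

    # strict=True turns the same mismatch into a hard error:
    try:
        dict(zip(seqs, taps, strict=True))
    except ValueError as err:
        print(err)  # zip() argument 2 is shorter than argument 1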

File tree: 4 files changed, +80 −47 lines changed

pytensor/scan/basic.py

Lines changed: 5 additions & 3 deletions
@@ -884,7 +884,9 @@ def wrap_into_list(x):
     if condition is not None:
         outputs.append(condition)
     fake_nonseqs = [x.type() for x in non_seqs]
-    fake_outputs = clone_replace(outputs, replace=dict(zip(non_seqs, fake_nonseqs)))
+    fake_outputs = clone_replace(
+        outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
+    )
     all_inputs = filter(
         lambda x: (
             isinstance(x, Variable)
@@ -1047,7 +1049,7 @@ def wrap_into_list(x):
         if not isinstance(arg, SharedVariable | Constant)
     ]

-    inner_replacements.update(dict(zip(other_scan_args, other_inner_args)))
+    inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True)))

     if strict:
         non_seqs_set = set(non_sequences if non_sequences is not None else [])
@@ -1069,7 +1071,7 @@ def wrap_into_list(x):
     ]

     inner_replacements.update(
-        dict(zip(other_shared_scan_args, other_shared_inner_args))
+        dict(zip(other_shared_scan_args, other_shared_inner_args, strict=True))
     )

     ##

pytensor/scan/op.py

Lines changed: 43 additions & 21 deletions
@@ -170,7 +170,7 @@ def check_broadcast(v1, v2):
     )
     size = min(v1.type.ndim, v2.type.ndim)
     for n, (b1, b2) in enumerate(
-        zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
+        zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:], strict=False)
     ):
         if b1 != b2:
             a1 = n + size - v1.type.ndim + 1
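
Note that this zip in check_broadcast is the one place pinned to strict=False rather than strict=True. A plausible reason, assumed here rather than stated in the commit, is the [-size:] slicing: when size == 0, x[-0:] is x[0:], i.e. the whole sequence, so the two slices can legitimately differ in length. A sketch of that edge case:

    # Hypothetical illustration of why strict=False may be needed above:
    # slicing with [-size:] does not shorten a sequence when size == 0.
    b1 = (False, True)  # broadcastable pattern of a 2d type (illustrative)
    b2 = ()             # broadcastable pattern of a 0d type
    size = min(2, 0)    # min(v1.type.ndim, v2.type.ndim) == 0

    print(b1[-size:])  # (False, True) -- the full tuple, length 2
    print(b2[-size:])  # ()            -- length 0
    # zip over these slices with strict=True would raise ValueError;
    # strict=False keeps the original truncating behavior.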
@@ -577,6 +577,7 @@ def get_oinp_iinp_iout_oout_mappings(self):
             inner_input_indices,
             inner_output_indices,
             outer_output_indices,
+            strict=True,
         ):
             if oout != -1:
                 mappings["outer_inp_from_outer_out"][oout] = oinp
@@ -958,7 +959,7 @@ def make_node(self, *inputs):
         # them have the same dtype
         argoffset = 0
         for inner_seq, outer_seq in zip(
-            self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs)
+            self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs), strict=True
         ):
             check_broadcast(outer_seq, inner_seq)
             new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))
@@ -977,6 +978,7 @@ def make_node(self, *inputs):
                 self.info.mit_mot_in_slices,
                 self.info.mit_mot_out_slices[: self.info.n_mit_mot],
                 self.outer_mitmot(inputs),
+                strict=True,
             )
         ):
             outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos])
@@ -1031,6 +1033,7 @@ def make_node(self, *inputs):
                 self.info.mit_sot_in_slices,
                 self.outer_mitsot(inputs),
                 self.inner_mitsot_outs(self.inner_outputs),
+                strict=True,
             )
         ):
             outer_mitsot = copy_var_format(_outer_mitsot, as_var=inner_mitsots[ipos])
@@ -1083,6 +1086,7 @@ def make_node(self, *inputs):
                 self.inner_sitsot(self.inner_inputs),
                 self.outer_sitsot(inputs),
                 self.inner_sitsot_outs(self.inner_outputs),
+                strict=True,
             )
         ):
             outer_sitsot = copy_var_format(_outer_sitsot, as_var=inner_sitsot)
@@ -1130,6 +1134,7 @@ def make_node(self, *inputs):
                 self.inner_shared(self.inner_inputs),
                 self.inner_shared_outs(self.inner_outputs),
                 self.outer_shared(inputs),
+                strict=True,
             )
         ):
             outer_shared = copy_var_format(_outer_shared, as_var=inner_shared)
@@ -1188,7 +1193,9 @@ def make_node(self, *inputs):
         # type of tensor as the output, it is always a scalar int.
         new_inputs += [as_tensor_variable(ons) for ons in self.outer_nitsot(inputs)]
         for inner_nonseq, _outer_nonseq in zip(
-            self.inner_non_seqs(self.inner_inputs), self.outer_non_seqs(inputs)
+            self.inner_non_seqs(self.inner_inputs),
+            self.outer_non_seqs(inputs),
+            strict=True,
         ):
             outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq)
             new_inputs.append(outer_nonseq)
@@ -1271,7 +1278,9 @@ def __eq__(self, other):
         if len(self.inner_outputs) != len(other.inner_outputs):
             return False

-        for self_in, other_in in zip(self.inner_inputs, other.inner_inputs):
+        for self_in, other_in in zip(
+            self.inner_inputs, other.inner_inputs, strict=True
+        ):
             if self_in.type != other_in.type:
                 return False

@@ -1406,7 +1415,7 @@ def prepare_fgraph(self, fgraph):
         fgraph.attach_feature(
             Supervisor(
                 inp
-                for spec, inp in zip(wrapped_inputs, fgraph.inputs)
+                for spec, inp in zip(wrapped_inputs, fgraph.inputs, strict=True)
                 if not (
                     getattr(spec, "mutable", None)
                     or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp]))
@@ -2086,7 +2095,9 @@ def perform(self, node, inputs, output_storage):
                 jout = j + offset_out
                 output_storage[j][0] = inner_output_storage[jout].storage[0]

-            pos = [(idx + 1) % store for idx, store in zip(pos, store_steps)]
+            pos = [
+                (idx + 1) % store for idx, store in zip(pos, store_steps, strict=True)
+            ]
             i = i + 1

         # 6. Check if you need to re-order output buffers
@@ -2171,7 +2182,7 @@ def perform(self, node, inputs, output_storage):

     def infer_shape(self, fgraph, node, input_shapes):
         # input_shapes correspond to the shapes of node.inputs
-        for inp, inp_shp in zip(node.inputs, input_shapes):
+        for inp, inp_shp in zip(node.inputs, input_shapes, strict=True):
             assert inp_shp is None or len(inp_shp) == inp.type.ndim

         # Here we build 2 variables;
@@ -2240,7 +2251,9 @@ def infer_shape(self, fgraph, node, input_shapes):
         # Non-sequences have a direct equivalent from self.inner_inputs in
         # node.inputs
         inner_non_sequences = self.inner_inputs[len(seqs_shape) + len(outs_shape) :]
-        for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
+        for in_ns, out_ns in zip(
+            inner_non_sequences, node.inputs[offset:], strict=True
+        ):
             out_equivalent[in_ns] = out_ns

         if info.as_while:
@@ -2275,7 +2288,7 @@ def infer_shape(self, fgraph, node, input_shapes):
                 r = node.outputs[n_outs + x]
                 assert r.ndim == 1 + len(out_shape_x)
                 shp = [node.inputs[offset + info.n_shared_outs + x]]
-                for i, shp_i in zip(range(1, r.ndim), out_shape_x):
+                for i, shp_i in zip(range(1, r.ndim), out_shape_x, strict=True):
                     # Validate shp_i. v_shape_i is either None (if invalid),
                     # or a (variable, Boolean) tuple. The Boolean indicates
                     # whether variable is shp_i (if True), or an valid
@@ -2297,7 +2310,7 @@ def infer_shape(self, fgraph, node, input_shapes):
         if info.as_while:
             scan_outs_init = scan_outs
             scan_outs = []
-            for o, x in zip(node.outputs, scan_outs_init):
+            for o, x in zip(node.outputs, scan_outs_init, strict=True):
                 if x is None:
                     scan_outs.append(None)
                 else:
@@ -2573,7 +2586,9 @@ def compute_all_gradients(known_grads):
                     dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
                 else:
                     disconnected_dC_dinps_t[dx] = False
-                for Xt, Xt_placeholder in zip(diff_outputs[info.n_mit_mot_outs :], Xts):
+                for Xt, Xt_placeholder in zip(
+                    diff_outputs[info.n_mit_mot_outs :], Xts, strict=True
+                ):
                     tmp = forced_replace(dC_dinps_t[dx], Xt, Xt_placeholder)
                     dC_dinps_t[dx] = tmp

@@ -2653,7 +2668,9 @@ def compute_all_gradients(known_grads):
             n = n_steps.tag.test_value
         else:
             n = inputs[0].tag.test_value
-        for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs)):
+        for taps, x in zip(
+            info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True
+        ):
             mintap = np.min(taps)
             if hasattr(x[::-1][:mintap], "test_value"):
                 assert x[::-1][:mintap].tag.test_value.shape[0] == n
@@ -2668,7 +2685,9 @@ def compute_all_gradients(known_grads):
                 assert x[::-1].tag.test_value.shape[0] == n
         outer_inp_seqs += [
             x[::-1][: np.min(taps)]
-            for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs))
+            for taps, x in zip(
+                info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True
+            )
         ]
         outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
         outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
@@ -2999,6 +3018,7 @@ def compute_all_gradients(known_grads):
             zip(
                 outputs[offset : offset + info.n_seqs],
                 type_outs[offset : offset + info.n_seqs],
+                strict=True,
             )
         ):
             if t == "connected":
@@ -3028,7 +3048,7 @@ def compute_all_gradients(known_grads):
                 gradients.append(NullType(t)())

         end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
-        for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
+        for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end], strict=True)):
             if t == "connected":
                 # If the forward scan is in as_while mode, we need to pad
                 # the gradients, so that they match the size of the input
@@ -3063,7 +3083,7 @@ def compute_all_gradients(known_grads):
         for idx in range(info.n_shared_outs):
             disconnected = True
             connected_flags = self.connection_pattern(node)[idx + start]
-            for dC_dout, connected in zip(dC_douts, connected_flags):
+            for dC_dout, connected in zip(dC_douts, connected_flags, strict=True):
                 if not isinstance(dC_dout.type, DisconnectedType) and connected:
                     disconnected = False
             if disconnected:
@@ -3080,7 +3100,9 @@ def compute_all_gradients(known_grads):
         begin = end

         end = begin + n_sitsot_outs
-        for p, (x, t) in enumerate(zip(outputs[begin:end], type_outs[begin:end])):
+        for p, (x, t) in enumerate(
+            zip(outputs[begin:end], type_outs[begin:end], strict=True)
+        ):
             if t == "connected":
                 gradients.append(x[-1])
             elif t == "disconnected":
@@ -3157,7 +3179,7 @@ def R_op(self, inputs, eval_points):
         e = 1 + info.n_seqs
         ie = info.n_seqs
         clean_eval_points = []
-        for inp, evp in zip(inputs[b:e], eval_points[b:e]):
+        for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
             if evp is not None:
                 clean_eval_points.append(evp)
             else:
@@ -3172,7 +3194,7 @@ def R_op(self, inputs, eval_points):
         ib = ie
         ie = ie + int(sum(len(x) for x in info.mit_mot_in_slices))
         clean_eval_points = []
-        for inp, evp in zip(inputs[b:e], eval_points[b:e]):
+        for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
             if evp is not None:
                 clean_eval_points.append(evp)
             else:
@@ -3187,7 +3209,7 @@ def R_op(self, inputs, eval_points):
         ib = ie
         ie = ie + int(sum(len(x) for x in info.mit_sot_in_slices))
         clean_eval_points = []
-        for inp, evp in zip(inputs[b:e], eval_points[b:e]):
+        for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
             if evp is not None:
                 clean_eval_points.append(evp)
             else:
@@ -3202,7 +3224,7 @@ def R_op(self, inputs, eval_points):
         ib = ie
         ie = ie + info.n_sit_sot
         clean_eval_points = []
-        for inp, evp in zip(inputs[b:e], eval_points[b:e]):
+        for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True):
             if evp is not None:
                 clean_eval_points.append(evp)
             else:
@@ -3226,7 +3248,7 @@ def R_op(self, inputs, eval_points):

         # All other arguments
         clean_eval_points = []
-        for inp, evp in zip(inputs[e:], eval_points[e:]):
+        for inp, evp in zip(inputs[e:], eval_points[e:], strict=True):
             if evp is not None:
                 clean_eval_points.append(evp)
             else:
