
Commit 34d138c

Make zips strict in pytensor/tensor
1 parent ca14899 commit 34d138c
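
For context, zip(..., strict=True) is a Python 3.10+ feature: instead of silently truncating to the shortest argument, it raises a ValueError when the iterables have different lengths. A minimal sketch of the difference (not part of this diff):

    a = [1, 2, 3]
    b = ["x", "y"]

    # Default behaviour silently drops the unmatched element.
    list(zip(a, b))               # [(1, 'x'), (2, 'y')]

    # strict=True turns the length mismatch into an error.
    list(zip(a, b, strict=True))  # ValueError: zip() argument 2 is shorter than argument 1

Making the zips strict therefore surfaces silent length mismatches as errors instead of quietly dropping elements.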

14 files changed: +157 −89 lines

pytensor/tensor/basic.py

Lines changed: 12 additions & 10 deletions
@@ -1524,6 +1524,7 @@ def make_node(self, value, *shape):
                 extended_value_broadcastable,
                 extended_value_static_shape,
                 static_shape,
+                strict=True,
             )
         ):
             # If value is not broadcastable and we don't know the target static shape: use value static shape
@@ -1544,7 +1545,7 @@ def make_node(self, value, *shape):
     def _check_runtime_broadcast(node, value, shape):
         value_static_shape = node.inputs[0].type.shape
         for v_static_dim, value_dim, out_dim in zip(
-            value_static_shape[::-1], value.shape[::-1], shape[::-1]
+            value_static_shape[::-1], value.shape[::-1], shape[::-1], strict=False
        ):
             if v_static_dim is None and value_dim == 1 and out_dim != 1:
                 raise ValueError(Alloc._runtime_broadcast_error_msg)
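
The reversed-shape zip above deliberately stays non-strict, presumably because the value passed to Alloc may have fewer dimensions than the target shape, so the reversed tuples can legitimately differ in length. A hypothetical sketch of that case (shapes are illustrative only, not from the diff):

    value_shape = (3, 1)        # 2-d value
    out_shape = (4, 3, 5)       # broadcast to a 3-d output

    # The non-strict zip aligns the trailing dimensions and ignores the extra
    # leading output dimension, which is what broadcasting requires.
    list(zip(value_shape[::-1], out_shape[::-1]))  # [(1, 5), (3, 3)]

    # strict=True would raise ValueError here even though the input is valid.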
@@ -1647,6 +1648,7 @@ def grad(self, inputs, grads):
                 inputs[0].type.shape,
                 # We need the dimensions corresponding to x
                 grads[0].type.shape[-inputs[0].ndim :],
+                strict=False,
             )
         ):
             if ib == 1 and gb != 1:
@@ -2157,7 +2159,7 @@ def grad(self, inputs, g_outputs):
             ]
         # Else, we have to make them zeros before joining them
         new_g_outputs = []
-        for o, g in zip(outputs, g_outputs):
+        for o, g in zip(outputs, g_outputs, strict=True):
             if isinstance(g.type, DisconnectedType):
                 new_g_outputs.append(o.zeros_like())
             else:
@@ -2586,7 +2588,7 @@ def grad(self, axis_and_tensors, grads):
                 else specify_broadcastable(
                     g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1)
                 )
-                for t, g in zip(tens, split_gz)
+                for t, g in zip(tens, split_gz, strict=True)
             ]
             rval = rval + split_gz
         else:
@@ -2694,7 +2696,7 @@ def vectorize_join(op: Join, node, batch_axis, *batch_inputs):
     ):
         batch_ndims = {
             batch_input.type.ndim - old_input.type.ndim
-            for batch_input, old_input in zip(batch_inputs, old_inputs)
+            for batch_input, old_input in zip(batch_inputs, old_inputs, strict=True)
         }
         if len(batch_ndims) == 1:
             [batch_ndim] = batch_ndims
@@ -3276,7 +3278,7 @@ def __getitem__(self, *args):
             tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j))
             for j, r in enumerate(ranges)
         ]
-        ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes)]
+        ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes, strict=True)]
         if self.sparse:
             grids = ranges
         else:
@@ -3348,7 +3350,7 @@ def make_node(self, x, y, inverse):

         out_shape = [
             1 if xb == 1 and yb == 1 else None
-            for xb, yb in zip(x.type.shape, y.type.shape)
+            for xb, yb in zip(x.type.shape, y.type.shape, strict=True)
         ]
         out_type = tensor(dtype=x.type.dtype, shape=out_shape)

@@ -3413,7 +3415,7 @@ def perform(self, node, inp, out):

         # Make sure the output is big enough
         out_s = []
-        for xdim, ydim in zip(x_s, y_s):
+        for xdim, ydim in zip(x_s, y_s, strict=True):
             if xdim == ydim:
                 outdim = xdim
             elif xdim == 1:
@@ -3473,7 +3475,7 @@ def grad(self, inp, grads):
         assert gx.type.ndim == x.type.ndim
         assert all(
             s1 == s2
-            for s1, s2 in zip(gx.type.shape, x.type.shape)
+            for s1, s2 in zip(gx.type.shape, x.type.shape, strict=True)
             if s1 == 1 or s2 == 1
         )

@@ -3983,7 +3985,7 @@ def moveaxis(

     order = [n for n in range(a.ndim) if n not in source]

-    for dest, src in sorted(zip(destination, source)):
+    for dest, src in sorted(zip(destination, source, strict=True)):
         order.insert(dest, src)

     result = a.dimshuffle(order)
@@ -4337,7 +4339,7 @@ def _make_along_axis_idx(arr_shape, indices, axis):
     # build a fancy index, consisting of orthogonal aranges, with the
     # requested index inserted at the right location
     fancy_index = []
-    for dim, n in zip(dest_dims, arr_shape):
+    for dim, n in zip(dest_dims, arr_shape, strict=True):
         if dim is None:
             fancy_index.append(indices)
         else:

pytensor/tensor/blockwise.py

Lines changed: 35 additions & 19 deletions
@@ -88,7 +88,7 @@ def __getstate__(self):

     def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
         core_input_types = []
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             if inp.type.ndim < len(sig):
                 raise ValueError(
                     f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
@@ -106,7 +106,9 @@ def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
             raise ValueError(
                 f"Insufficient number of outputs for signature {self.signature}: {len(core_node.outputs)}"
             )
-        for i, (core_out, sig) in enumerate(zip(core_node.outputs, self.outputs_sig)):
+        for i, (core_out, sig) in enumerate(
+            zip(core_node.outputs, self.outputs_sig, strict=True)
+        ):
             if core_out.type.ndim != len(sig):
                 raise ValueError(
                     f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}"
@@ -120,12 +122,13 @@ def make_node(self, *inputs):
         core_node = self._create_dummy_core_node(inputs)

         batch_ndims = max(
-            inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig)
+            inp.type.ndim - len(sig)
+            for inp, sig in zip(inputs, self.inputs_sig, strict=True)
         )

         batched_inputs = []
         batch_shapes = []
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             # Append missing dims to the left
             missing_batch_ndims = batch_ndims - (inp.type.ndim - len(sig))
             if missing_batch_ndims:
@@ -141,7 +144,7 @@ def make_node(self, *inputs):
                 batch_shape = tuple(
                     [
                         broadcast_static_dim_lengths(batch_dims)
-                        for batch_dims in zip(*batch_shapes)
+                        for batch_dims in zip(*batch_shapes, strict=True)
                     ]
                 )
             except ValueError:
@@ -168,10 +171,10 @@ def infer_shape(
         batch_ndims = self.batch_ndim(node)
         core_dims: dict[str, Any] = {}
         batch_shapes = [input_shape[:batch_ndims] for input_shape in input_shapes]
-        for input_shape, sig in zip(input_shapes, self.inputs_sig):
+        for input_shape, sig in zip(input_shapes, self.inputs_sig, strict=True):
             core_shape = input_shape[batch_ndims:]

-            for core_dim, dim_name in zip(core_shape, sig):
+            for core_dim, dim_name in zip(core_shape, sig, strict=True):
                 prev_core_dim = core_dims.get(core_dim)
                 if prev_core_dim is None:
                     core_dims[dim_name] = core_dim
@@ -182,7 +185,7 @@ def infer_shape(
         batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)

         out_shapes = []
-        for output, sig in zip(node.outputs, self.outputs_sig):
+        for output, sig in zip(node.outputs, self.outputs_sig, strict=True):
             core_out_shape = []
             for i, dim_name in enumerate(sig):
                 # The output dim is the same as another input dim
@@ -213,17 +216,17 @@ def as_core(t, core_t):
         with config.change_flags(compute_test_value="off"):
             safe_inputs = [
                 tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
-                for inp, sig in zip(inputs, self.inputs_sig)
+                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
             ]
             core_node = self._create_dummy_core_node(safe_inputs)

             core_inputs = [
                 as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs)
+                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
             ]
             core_ograds = [
                 as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs)
+                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
             ]
             core_outputs = core_node.outputs

@@ -232,7 +235,11 @@ def as_core(t, core_t):
             igrads = vectorize_graph(
                 [core_igrad for core_igrad in core_igrads if core_igrad is not None],
                 replace=dict(
-                    zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
+                    zip(
+                        core_inputs + core_outputs + core_ograds,
+                        inputs + outputs + ograds,
+                        strict=True,
+                    )
                 ),
             )

@@ -258,7 +265,7 @@ def L_op(self, inputs, outs, ograds):
         # the return value obviously zero so that gradient.grad can tell
         # this op did the right thing.
         new_rval = []
-        for elem, inp in zip(rval, inputs):
+        for elem, inp in zip(rval, inputs, strict=True):
             if isinstance(elem.type, NullType | DisconnectedType):
                 new_rval.append(elem)
             else:
@@ -272,15 +279,17 @@ def L_op(self, inputs, outs, ograds):
         # Sum out the broadcasted dimensions
         batch_ndims = self.batch_ndim(outs[0].owner)
         batch_shape = outs[0].type.shape[:batch_ndims]
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             if isinstance(rval[i].type, NullType | DisconnectedType):
                 continue

             assert inp.type.ndim == batch_ndims + len(sig)

             to_sum = [
                 j
-                for j, (inp_s, out_s) in enumerate(zip(inp.type.shape, batch_shape))
+                for j, (inp_s, out_s) in enumerate(
+                    zip(inp.type.shape, batch_shape, strict=False)
+                )
                 if inp_s == 1 and out_s != 1
             ]
             if to_sum:
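
The inner zip in this hunk keeps strict=False, presumably because inp.type.shape covers both batch and core dimensions while batch_shape covers only the leading batch dimensions, so the two sequences can have different lengths. A hypothetical sketch with illustrative shapes (not from the diff):

    inp_shape = (7, 1, 4, 5)   # 2 batch dims followed by 2 core dims
    batch_shape = (7, 3)       # output batch dims only

    # The non-strict zip pairs just the leading batch dims; strict=True would raise.
    to_sum = [
        j
        for j, (inp_s, out_s) in enumerate(zip(inp_shape, batch_shape))
        if inp_s == 1 and out_s != 1
    ]
    print(to_sum)  # [1]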
@@ -320,9 +329,14 @@ def _check_runtime_broadcast(self, node, inputs):

         for dims_and_bcast in zip(
             *[
-                zip(input.shape[:batch_ndim], sinput.type.broadcastable[:batch_ndim])
-                for input, sinput in zip(inputs, node.inputs)
-            ]
+                zip(
+                    input.shape[:batch_ndim],
+                    sinput.type.broadcastable[:batch_ndim],
+                    strict=True,
+                )
+                for input, sinput in zip(inputs, node.inputs, strict=True)
+            ],
+            strict=True,
         ):
             if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
                 raise ValueError(
@@ -343,7 +357,9 @@ def perform(self, node, inputs, output_storage):
         if not isinstance(res, tuple):
             res = (res,)

-        for node_out, out_storage, r in zip(node.outputs, output_storage, res):
+        for node_out, out_storage, r in zip(
+            node.outputs, output_storage, res, strict=True
+        ):
             out_dtype = getattr(node_out, "dtype", None)
             if out_dtype and out_dtype != r.dtype:
                 r = np.asarray(r, dtype=out_dtype)

pytensor/tensor/conv/abstract_conv.py

Lines changed: 1 addition & 1 deletion
@@ -506,7 +506,7 @@ def check_dim(given, computed):

     return all(
         check_dim(given, computed)
-        for (given, computed) in zip(output_shape, computed_output_shape)
+        for (given, computed) in zip(output_shape, computed_output_shape, strict=True)
     )
