Commit 6891d06

Make zips strict in pytensor/tensor
1 parent c0779ec commit 6891d06
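
For context: strict=True is the flag the built-in zip() gained in Python 3.10. A plain zip() stops at the shortest iterable and silently drops the rest; with strict=True it raises ValueError as soon as the iterables turn out to have different lengths, which is the behaviour this commit opts into throughout pytensor/tensor. A minimal, standalone illustration in plain Python (not PyTensor code):

static_shape = (None, 3)
runtime_shape = (2, 3, 4)  # one dimension too many

# Default zip silently truncates to the shorter sequence:
assert list(zip(static_shape, runtime_shape)) == [(None, 2), (3, 3)]

# Strict zip turns the same mismatch into an error:
try:
    list(zip(static_shape, runtime_shape, strict=True))
except ValueError as exc:
    print(exc)  # e.g. "zip() argument 2 is longer than argument 1"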

14 files changed, +157 −89 lines

pytensor/tensor/basic.py

Lines changed: 12 additions & 10 deletions
@@ -1524,6 +1524,7 @@ def make_node(self, value, *shape):
                 extended_value_broadcastable,
                 extended_value_static_shape,
                 static_shape,
+                strict=True,
             )
         ):
             # If value is not broadcastable and we don't know the target static shape: use value static shape
@@ -1544,7 +1545,7 @@ def make_node(self, value, *shape):
     def _check_runtime_broadcast(node, value, shape):
         value_static_shape = node.inputs[0].type.shape
         for v_static_dim, value_dim, out_dim in zip(
-            value_static_shape[::-1], value.shape[::-1], shape[::-1]
+            value_static_shape[::-1], value.shape[::-1], shape[::-1], strict=True
         ):
             if v_static_dim is None and value_dim == 1 and out_dim != 1:
                 raise ValueError(Alloc._runtime_broadcast_error_msg)
@@ -1650,6 +1651,7 @@ def grad(self, inputs, grads):
                 inputs[0].type.shape,
                 # We need the dimensions corresponding to x
                 grads[0].type.shape[-inputs[0].ndim :],
+                strict=False,
             )
         ):
             if ib == 1 and gb != 1:
@@ -2162,7 +2164,7 @@ def grad(self, inputs, g_outputs):
             ]
         # Else, we have to make them zeros before joining them
         new_g_outputs = []
-        for o, g in zip(outputs, g_outputs):
+        for o, g in zip(outputs, g_outputs, strict=True):
             if isinstance(g.type, DisconnectedType):
                 new_g_outputs.append(o.zeros_like())
             else:
@@ -2593,7 +2595,7 @@ def grad(self, axis_and_tensors, grads):
                 else specify_broadcastable(
                     g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1)
                 )
-                for t, g in zip(tens, split_gz)
+                for t, g in zip(tens, split_gz, strict=True)
             ]
             rval = rval + split_gz
         else:
@@ -2701,7 +2703,7 @@ def vectorize_join(op: Join, node, batch_axis, *batch_inputs):
     ):
         batch_ndims = {
             batch_input.type.ndim - old_input.type.ndim
-            for batch_input, old_input in zip(batch_inputs, old_inputs)
+            for batch_input, old_input in zip(batch_inputs, old_inputs, strict=True)
         }
         if len(batch_ndims) == 1:
             [batch_ndim] = batch_ndims
@@ -3283,7 +3285,7 @@ def __getitem__(self, *args):
                 tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j))
                 for j, r in enumerate(ranges)
             ]
-            ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes)]
+            ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes, strict=True)]
             if self.sparse:
                 grids = ranges
             else:
@@ -3355,7 +3357,7 @@ def make_node(self, x, y, inverse):
 
         out_shape = [
             1 if xb == 1 and yb == 1 else None
-            for xb, yb in zip(x.type.shape, y.type.shape)
+            for xb, yb in zip(x.type.shape, y.type.shape, strict=True)
         ]
         out_type = tensor(dtype=x.type.dtype, shape=out_shape)
 
@@ -3420,7 +3422,7 @@ def perform(self, node, inp, out):
 
         # Make sure the output is big enough
         out_s = []
-        for xdim, ydim in zip(x_s, y_s):
+        for xdim, ydim in zip(x_s, y_s, strict=True):
             if xdim == ydim:
                 outdim = xdim
             elif xdim == 1:
@@ -3482,7 +3484,7 @@ def grad(self, inp, grads):
         assert gx.type.ndim == x.type.ndim
         assert all(
             s1 == s2
-            for s1, s2 in zip(gx.type.shape, x.type.shape)
+            for s1, s2 in zip(gx.type.shape, x.type.shape, strict=True)
             if s1 == 1 or s2 == 1
         )
 
@@ -3992,7 +3994,7 @@ def moveaxis(
 
     order = [n for n in range(a.ndim) if n not in source]
 
-    for dest, src in sorted(zip(destination, source)):
+    for dest, src in sorted(zip(destination, source, strict=True)):
         order.insert(dest, src)
 
     result = a.dimshuffle(order)
@@ -4346,7 +4348,7 @@ def _make_along_axis_idx(arr_shape, indices, axis):
     # build a fancy index, consisting of orthogonal aranges, with the
     # requested index inserted at the right location
     fancy_index = []
-    for dim, n in zip(dest_dims, arr_shape):
+    for dim, n in zip(dest_dims, arr_shape, strict=True):
         if dim is None:
             fancy_index.append(indices)
         else:
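
Most of the basic.py changes guard shape bookkeeping: static shapes, runtime shapes and target shapes are walked in lockstep, and a rank mismatch used to be masked by zip's truncation. (The one hunk that sets strict=False, in the grad method around line 1650, presumably pairs sequences whose lengths can legitimately differ, so truncation is the intended behaviour there.) Below is a minimal, hypothetical sketch of the kind of check these loops perform, mirroring _check_runtime_broadcast above but with made-up shapes, not the actual PyTensor implementation:

def check_runtime_broadcast(value_static_shape, value_shape, target_shape):
    # Walk the three shapes in lockstep, trailing dimensions first.
    # With strict=True a rank mismatch between the shapes raises ValueError
    # instead of being silently truncated away.
    for v_static_dim, value_dim, out_dim in zip(
        value_static_shape[::-1], value_shape[::-1], target_shape[::-1], strict=True
    ):
        if v_static_dim is None and value_dim == 1 and out_dim != 1:
            raise ValueError("runtime broadcasting detected")


check_runtime_broadcast((None, 3), (5, 3), (5, 3))     # passes
check_runtime_broadcast((5, None), (5, 1, 1), (5, 4))  # ValueError from the strict zip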

pytensor/tensor/blockwise.py

Lines changed: 35 additions & 19 deletions
@@ -88,7 +88,7 @@ def __getstate__(self):
 
     def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
         core_input_types = []
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             if inp.type.ndim < len(sig):
                 raise ValueError(
                     f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
@@ -106,7 +106,9 @@ def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
             raise ValueError(
                 f"Insufficient number of outputs for signature {self.signature}: {len(core_node.outputs)}"
             )
-        for i, (core_out, sig) in enumerate(zip(core_node.outputs, self.outputs_sig)):
+        for i, (core_out, sig) in enumerate(
+            zip(core_node.outputs, self.outputs_sig, strict=True)
+        ):
             if core_out.type.ndim != len(sig):
                 raise ValueError(
                     f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}"
@@ -120,12 +122,13 @@ def make_node(self, *inputs):
         core_node = self._create_dummy_core_node(inputs)
 
         batch_ndims = max(
-            inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig)
+            inp.type.ndim - len(sig)
+            for inp, sig in zip(inputs, self.inputs_sig, strict=True)
         )
 
         batched_inputs = []
         batch_shapes = []
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             # Append missing dims to the left
             missing_batch_ndims = batch_ndims - (inp.type.ndim - len(sig))
             if missing_batch_ndims:
@@ -141,7 +144,7 @@ def make_node(self, *inputs):
             batch_shape = tuple(
                 [
                     broadcast_static_dim_lengths(batch_dims)
-                    for batch_dims in zip(*batch_shapes)
+                    for batch_dims in zip(*batch_shapes, strict=True)
                 ]
             )
         except ValueError:
@@ -168,11 +171,11 @@ def infer_shape(
         batch_ndims = self.batch_ndim(node)
         core_dims: dict[str, Any] = {}
         batch_shapes = []
-        for input_shape, sig in zip(input_shapes, self.inputs_sig):
+        for input_shape, sig in zip(input_shapes, self.inputs_sig, strict=True):
             batch_shapes.append(input_shape[:batch_ndims])
             core_shape = input_shape[batch_ndims:]
 
-            for core_dim, dim_name in zip(core_shape, sig):
+            for core_dim, dim_name in zip(core_shape, sig, strict=True):
                 prev_core_dim = core_dims.get(core_dim)
                 if prev_core_dim is None:
                     core_dims[dim_name] = core_dim
@@ -183,7 +186,7 @@ def infer_shape(
         batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)
 
         out_shapes = []
-        for output, sig in zip(node.outputs, self.outputs_sig):
+        for output, sig in zip(node.outputs, self.outputs_sig, strict=True):
             core_out_shape = []
             for i, dim_name in enumerate(sig):
                 # The output dim is the same as another input dim
@@ -214,17 +217,17 @@ def as_core(t, core_t):
         with config.change_flags(compute_test_value="off"):
             safe_inputs = [
                 tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
-                for inp, sig in zip(inputs, self.inputs_sig)
+                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
             ]
             core_node = self._create_dummy_core_node(safe_inputs)
 
             core_inputs = [
                 as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs)
+                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
             ]
             core_ograds = [
                 as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs)
+                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
             ]
             core_outputs = core_node.outputs
 
@@ -233,7 +236,11 @@ def as_core(t, core_t):
             igrads = vectorize_graph(
                 [core_igrad for core_igrad in core_igrads if core_igrad is not None],
                 replace=dict(
-                    zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
+                    zip(
+                        core_inputs + core_outputs + core_ograds,
+                        inputs + outputs + ograds,
+                        strict=True,
+                    )
                 ),
             )
 
@@ -259,7 +266,7 @@ def L_op(self, inputs, outs, ograds):
         # the return value obviously zero so that gradient.grad can tell
         # this op did the right thing.
         new_rval = []
-        for elem, inp in zip(rval, inputs):
+        for elem, inp in zip(rval, inputs, strict=True):
             if isinstance(elem.type, NullType | DisconnectedType):
                 new_rval.append(elem)
             else:
@@ -273,15 +280,17 @@ def L_op(self, inputs, outs, ograds):
         # Sum out the broadcasted dimensions
         batch_ndims = self.batch_ndim(outs[0].owner)
         batch_shape = outs[0].type.shape[:batch_ndims]
-        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
+        for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
             if isinstance(rval[i].type, NullType | DisconnectedType):
                 continue
 
             assert inp.type.ndim == batch_ndims + len(sig)
 
             to_sum = [
                 j
-                for j, (inp_s, out_s) in enumerate(zip(inp.type.shape, batch_shape))
+                for j, (inp_s, out_s) in enumerate(
+                    zip(inp.type.shape, batch_shape, strict=True)
+                )
                 if inp_s == 1 and out_s != 1
             ]
             if to_sum:
@@ -321,9 +330,14 @@ def _check_runtime_broadcast(self, node, inputs):
 
         for dims_and_bcast in zip(
             *[
-                zip(input.shape[:batch_ndim], sinput.type.broadcastable[:batch_ndim])
-                for input, sinput in zip(inputs, node.inputs)
-            ]
+                zip(
+                    input.shape[:batch_ndim],
+                    sinput.type.broadcastable[:batch_ndim],
+                    strict=True,
+                )
+                for input, sinput in zip(inputs, node.inputs, strict=True)
+            ],
+            strict=True,
         ):
             if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
                 raise ValueError(
@@ -344,7 +358,9 @@ def perform(self, node, inputs, output_storage):
         if not isinstance(res, tuple):
            res = (res,)
 
-        for node_out, out_storage, r in zip(node.outputs, output_storage, res):
+        for node_out, out_storage, r in zip(
+            node.outputs, output_storage, res, strict=True
+        ):
             out_dtype = getattr(node_out, "dtype", None)
             if out_dtype and out_dtype != r.dtype:
                 r = np.asarray(r, dtype=out_dtype)
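
In blockwise.py most of the touched zips pair inputs, outputs or gradients with entries of the op's gufunc-style signature, so a strict zip turns a count mismatch between variables and signature terms into an immediate error rather than a silently skipped entry. A rough, hypothetical sketch of that pairing, with a hand-written signature instead of the real Blockwise parsing:

# Parsed input signature of a matrix-vector product "(m,n),(n)->(m)".
inputs_sig = [("m", "n"), ("n",)]
inputs = ["A"]  # one input short of what the signature requires

try:
    for i, (inp, sig) in enumerate(zip(inputs, inputs_sig, strict=True)):
        print(f"input {i}: {inp} has core dims {sig}")
except ValueError as exc:
    # A plain zip would simply skip the missing input;
    # strict=True surfaces the mismatch right away.
    print(f"signature/input mismatch: {exc}")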

pytensor/tensor/conv/abstract_conv.py

Lines changed: 1 addition & 1 deletion
@@ -506,7 +506,7 @@ def check_dim(given, computed):
 
     return all(
         check_dim(given, computed)
-        for (given, computed) in zip(output_shape, computed_output_shape)
+        for (given, computed) in zip(output_shape, computed_output_shape, strict=True)
     )