
Commit 9a3e9e3

Make zips strict in pytensor/tensor/random
1 parent: 9c0d35d · commit: 9a3e9e3

File tree

4 files changed: +27 -17 lines

pytensor/tensor/random/basic.py
pytensor/tensor/random/op.py
pytensor/tensor/random/rewriting/basic.py
pytensor/tensor/random/utils.py
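As background (not part of the commit itself): starting with Python 3.10, zip() accepts a strict flag, and zip(..., strict=True) raises a ValueError when its arguments have different lengths instead of silently truncating to the shortest. A minimal sketch of the behaviour the changes below opt into:

    # Minimal sketch of zip's strict flag (requires Python >= 3.10).
    sizes = (2, 3)
    batch_shape = (2, 3, 4)  # one extra trailing entry

    # Plain zip silently drops the unmatched 4:
    print(list(zip(sizes, batch_shape)))  # [(2, 2), (3, 3)]

    # strict=True turns the length mismatch into an explicit error:
    try:
        list(zip(sizes, batch_shape, strict=True))
    except ValueError as err:
        print(err)  # e.g. "zip() argument 2 is longer than argument 1"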

pytensor/tensor/random/basic.py

Lines changed: 1 addition & 1 deletion
@@ -1862,7 +1862,7 @@ def rng_fn(cls, rng, p, size):
             # to `p.shape[:-1]` in the call to `vsearchsorted` below.
             if len(size) < (p.ndim - 1):
                 raise ValueError("`size` is incompatible with the shape of `p`")
-            for s, ps in zip(reversed(size), reversed(p.shape[:-1])):
+            for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=True):
                 if s == 1 and ps != 1:
                     raise ValueError("`size` is incompatible with the shape of `p`")

pytensor/tensor/random/op.py

Lines changed: 9 additions & 5 deletions
@@ -152,11 +152,13 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
 
         # Try to infer missing support dims from signature of params
         for param, param_sig, ndim_params in zip(
-            dist_params, self.inputs_sig, self.ndims_params
+            dist_params, self.inputs_sig, self.ndims_params, strict=True
         ):
             if ndim_params == 0:
                 continue
-            for param_dim, dim in zip(param.shape[-ndim_params:], param_sig):
+            for param_dim, dim in zip(
+                param.shape[-ndim_params:], param_sig, strict=True
+            ):
                 if dim in core_out_shape and core_out_shape[dim] is None:
                     core_out_shape[dim] = param_dim
 
@@ -231,7 +233,7 @@ def _infer_shape(
 
             # Fail early when size is incompatible with parameters
             for i, (param, param_ndim_supp) in enumerate(
-                zip(dist_params, self.ndims_params)
+                zip(dist_params, self.ndims_params, strict=True)
             ):
                 param_batched_dims = getattr(param, "ndim", 0) - param_ndim_supp
                 if param_batched_dims > size_len:
@@ -255,7 +257,7 @@ def extract_batch_shape(p, ps, n):
 
             batch_shape = tuple(
                 s if not b else constant(1, "int64")
-                for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
+                for s, b in zip(shape[:-n], p.type.broadcastable[:-n], strict=True)
             )
             return batch_shape
 
@@ -264,7 +266,9 @@ def extract_batch_shape(p, ps, n):
         # independent variate dimensions are left.
         params_batch_shape = tuple(
             extract_batch_shape(p, ps, n)
-            for p, ps, n in zip(dist_params, param_shapes, self.ndims_params)
+            for p, ps, n in zip(
+                dist_params, param_shapes, self.ndims_params, strict=False
+            )
         )
 
         if len(params_batch_shape) == 1:

pytensor/tensor/random/rewriting/basic.py

Lines changed: 3 additions & 3 deletions
@@ -172,7 +172,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
 
     # Updates the params to reflect the Dimshuffled dimensions
     new_dist_params = []
-    for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
+    for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params, strict=True):
         # Add broadcastable dimensions to the parameters that would have been expanded by the size
         padleft = batched_dims - (param.ndim - param_ndim_supp)
         if padleft > 0:
@@ -290,7 +290,7 @@ def is_nd_advanced_idx(idx, dtype):
     # non-broadcastable (non-degenerate) parameter dims. These parameters and the new size
     # should still correctly broadcast any degenerate parameter dims.
     new_dist_params = []
-    for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
+    for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params, strict=True):
         # We first expand any missing parameter dims (and later index them away or keep them with none-slicing)
         batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp)
         batch_param = (
@@ -302,7 +302,7 @@ def is_nd_advanced_idx(idx, dtype):
         bcast_batch_param_dims = tuple(
             dim
             for dim, (param_dim, output_dim) in enumerate(
-                zip(batch_param.type.shape, rv.type.shape)
+                zip(batch_param.type.shape, rv.type.shape, strict=False)
             )
             if (param_dim == 1) and (output_dim != 1)
         )

pytensor/tensor/random/utils.py

Lines changed: 14 additions & 8 deletions
@@ -42,7 +42,7 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
     max_fn = maximum if use_pytensor else max
 
     rev_extra_dims = []
-    for ndim_param, param_shape in zip(ndims_params, param_shapes):
+    for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True):
         # We need this in order to use `len`
         param_shape = tuple(param_shape)
         extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
@@ -65,7 +65,7 @@ def max_bcast(x, y):
         (extra_dims + tuple(param_shape)[-ndim_param:])
         if ndim_param > 0
         else extra_dims
-        for ndim_param, param_shape in zip(ndims_params, param_shapes)
+        for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True)
     ]
 
     return bcast_shapes
@@ -104,7 +104,9 @@ def broadcast_params(params, ndims_params):
     for p in params:
         param_shape = tuple(
             1 if bcast else s
-            for s, bcast in zip(p.shape, getattr(p, "broadcastable", (False,) * p.ndim))
+            for s, bcast in zip(
+                p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=True
+            )
         )
         use_pytensor |= isinstance(p, Variable)
         param_shapes.append(param_shape)
@@ -115,7 +117,8 @@ def broadcast_params(params, ndims_params):
     broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to
 
     bcast_params = [
-        broadcast_to_fn(param, shape) for shape, param in zip(shapes, params)
+        broadcast_to_fn(param, shape)
+        for shape, param in zip(shapes, params, strict=True)
     ]
 
     return bcast_params
@@ -129,7 +132,8 @@ def explicit_expand_dims(
     """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
 
     batch_dims = [
-        param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params)
+        param.type.ndim - ndim_param
+        for param, ndim_param in zip(params, ndim_params, strict=False)
     ]
 
     if size_length is not None:
@@ -138,7 +142,7 @@
         max_batch_dims = max(batch_dims, default=0)
 
     new_params = []
-    for new_param, batch_dim in zip(params, batch_dims):
+    for new_param, batch_dim in zip(params, batch_dims, strict=True):
         missing_dims = max_batch_dims - batch_dim
         if missing_dims:
             new_param = shape_padleft(new_param, missing_dims)
@@ -153,7 +157,7 @@ def compute_batch_shape(
     params = explicit_expand_dims(params, ndims_params)
     batch_params = [
         param[(..., *(0,) * core_ndim)]
-        for param, core_ndim in zip(params, ndims_params)
+        for param, core_ndim in zip(params, ndims_params, strict=True)
    ]
     return broadcast_arrays(*batch_params)[0].shape
 
@@ -269,7 +273,9 @@ def seed(self, seed=None):
         self.gen_seedgen = np.random.SeedSequence(seed)
         old_r_seeds = self.gen_seedgen.spawn(len(self.state_updates))
 
-        for (old_r, new_r), old_r_seed in zip(self.state_updates, old_r_seeds):
+        for (old_r, new_r), old_r_seed in zip(
+            self.state_updates, old_r_seeds, strict=True
+        ):
             old_r.set_value(self.rng_ctor(old_r_seed), borrow=True)
 
     def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable:
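As an illustrative usage sketch (my own example, not from the commit): with these strict zips in place, a helper such as broadcast_params from pytensor/tensor/random/utils.py now fails loudly when params and ndims_params have different lengths, rather than silently ignoring the extras. A few call sites in the diff deliberately keep strict=False (the params_batch_shape comprehension in op.py, the degenerate-dims zip in rewriting/basic.py, and batch_dims in explicit_expand_dims), presumably because those sequences may legitimately differ in length.

    # Hypothetical example with NumPy inputs; the names mu/sigma are illustrative only.
    import numpy as np

    from pytensor.tensor.random.utils import broadcast_params

    mu = np.array([0.0, 1.0, 2.0])  # batch shape (3,)
    sigma = np.array(1.0)           # scalar
    # Both parameters have 0 core dims, so only their batch dims are broadcast together.
    bcast_mu, bcast_sigma = broadcast_params([mu, sigma], ndims_params=[0, 0])
    print(bcast_mu.shape, bcast_sigma.shape)  # (3,) (3,)

    # After this commit, a length mismatch is an error instead of a silent truncation:
    try:
        broadcast_params([mu, sigma], ndims_params=[0])
    except ValueError:
        print("params and ndims_params must have the same length")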
