@@ -42,7 +42,7 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
     max_fn = maximum if use_pytensor else max
 
     rev_extra_dims = []
-    for ndim_param, param_shape in zip(ndims_params, param_shapes):
+    for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True):
         # We need this in order to use `len`
         param_shape = tuple(param_shape)
         extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
@@ -65,7 +65,7 @@ def max_bcast(x, y):
         (extra_dims + tuple(param_shape)[-ndim_param:])
         if ndim_param > 0
         else extra_dims
-        for ndim_param, param_shape in zip(ndims_params, param_shapes)
+        for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True)
     ]
 
     return bcast_shapes
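For context (not part of the diff): `strict=` is the PEP 618 keyword added to `zip` in Python 3.10, so these hunks implicitly assume that minimum version. A minimal standalone sketch of the failure mode the `strict=True` call sites guard against, using made-up inputs rather than the code above:

```python
# Standalone sketch with made-up inputs, not code from this diff.
# By default zip truncates to the shortest iterable, so a caller passing
# mismatched ndims_params and param_shapes would be silently accepted.
ndims_params = [1, 1, 0]
param_shapes = [(2, 3), (3,)]  # one shape short

print(list(zip(ndims_params, param_shapes)))
# [(1, (2, 3)), (1, (3,))] -- the mismatch goes unnoticed

try:
    list(zip(ndims_params, param_shapes, strict=True))
except ValueError as exc:
    print(exc)  # zip() argument 2 is shorter than argument 1
```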
@@ -104,7 +104,9 @@ def broadcast_params(params, ndims_params):
     for p in params:
         param_shape = tuple(
             1 if bcast else s
-            for s, bcast in zip(p.shape, getattr(p, "broadcastable", (False,) * p.ndim))
+            for s, bcast in zip(
+                p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=True
+            )
         )
         use_pytensor |= isinstance(p, Variable)
         param_shapes.append(param_shape)
@@ -115,7 +117,8 @@ def broadcast_params(params, ndims_params):
     broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to
 
     bcast_params = [
-        broadcast_to_fn(param, shape) for shape, param in zip(shapes, params)
+        broadcast_to_fn(param, shape)
+        for shape, param in zip(shapes, params, strict=True)
     ]
 
     return bcast_params
@@ -129,7 +132,8 @@ def explicit_expand_dims(
     """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
 
     batch_dims = [
-        param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params)
+        param.type.ndim - ndim_param
+        for param, ndim_param in zip(params, ndim_params, strict=False)
     ]
 
     if size_length is not None:
@@ -138,7 +142,7 @@ def explicit_expand_dims(
         max_batch_dims = max(batch_dims, default=0)
 
     new_params = []
-    for new_param, batch_dim in zip(params, batch_dims):
+    for new_param, batch_dim in zip(params, batch_dims, strict=True):
         missing_dims = max_batch_dims - batch_dim
         if missing_dims:
             new_param = shape_padleft(new_param, missing_dims)
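Note that the `batch_dims` comprehension above is the one call site in this diff that passes `strict=False`, i.e. it explicitly keeps `zip`'s default truncating behavior rather than asserting equal lengths. A small sketch of what that spelling permits, with illustrative values only (not the real `params`/`ndim_params` objects):

```python
# Illustrative values only; with strict=False, extra entries in the longer
# argument are silently dropped instead of raising ValueError.
param_ndims = [3, 2]     # stand-ins for param.type.ndim
ndim_params = [1, 0, 0]  # one more core-ndim entry than params

batch_dims = [
    p_ndim - nd
    for p_ndim, nd in zip(param_ndims, ndim_params, strict=False)
]
print(batch_dims)  # [2, 2] -- the trailing ndim_params entry is ignored
```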
@@ -153,7 +157,7 @@ def compute_batch_shape(
     params = explicit_expand_dims(params, ndims_params)
     batch_params = [
         param[(..., *(0,) * core_ndim)]
-        for param, core_ndim in zip(params, ndims_params)
+        for param, core_ndim in zip(params, ndims_params, strict=True)
     ]
     return broadcast_arrays(*batch_params)[0].shape
 
@@ -269,7 +273,9 @@ def seed(self, seed=None):
         self.gen_seedgen = np.random.SeedSequence(seed)
         old_r_seeds = self.gen_seedgen.spawn(len(self.state_updates))
 
-        for (old_r, new_r), old_r_seed in zip(self.state_updates, old_r_seeds):
+        for (old_r, new_r), old_r_seed in zip(
+            self.state_updates, old_r_seeds, strict=True
+        ):
             old_r.set_value(self.rng_ctor(old_r_seed), borrow=True)
 
     def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable:
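In the `seed` hunk, both sides of the `zip` have the same length by construction (`old_r_seeds` comes from `spawn(len(self.state_updates))`), so `strict=True` acts as a cheap runtime assertion on that invariant. A self-contained sketch with placeholder state, assuming only NumPy:

```python
# Self-contained sketch with placeholder state, not the class from the diff.
import numpy as np

state_updates = [("rng0", "rng0_new"), ("rng1", "rng1_new")]  # stand-ins
gen_seedgen = np.random.SeedSequence(123)
old_r_seeds = gen_seedgen.spawn(len(state_updates))  # lengths match by construction

# strict=True would raise ValueError if this invariant were ever broken.
for (old_r, new_r), old_r_seed in zip(state_updates, old_r_seeds, strict=True):
    print(f"{old_r}: reseeded from spawn_key={old_r_seed.spawn_key}")
```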