Skip to content

Commit 0767523 — "Make zips strict in pytensor/link/jax" (1 parent: d9bc7b4)

File tree: 4 files changed, +17 / -9 lines.

pytensor/link/jax/dispatch/scan.py

Lines changed: 10 additions & 5 deletions
```diff
@@ -30,7 +30,9 @@ def scan(*outer_inputs):
     seqs = op.outer_seqs(outer_inputs)  # JAX `xs`

     mit_sot_init = []
-    for tap, seq in zip(op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs)):
+    for tap, seq in zip(
+        op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs), strict=True
+    ):
         init_slice = seq[: abs(min(tap))]
         mit_sot_init.append(init_slice)
```

```diff
@@ -61,7 +63,9 @@ def jax_args_to_inner_func_args(carry, x):
         inner_seqs = x

         mit_sot_flatten = []
-        for array, index in zip(inner_mit_sot, op.info.mit_sot_in_slices):
+        for array, index in zip(
+            inner_mit_sot, op.info.mit_sot_in_slices, strict=True
+        ):
             mit_sot_flatten.extend(array[jnp.array(index)])

         inner_scan_inputs = [
```
inner_scan_inputs = [
```diff
@@ -98,8 +102,7 @@ def inner_func_outs_to_jax_outs(
         inner_mit_sot_new = [
             jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
             for old_mit_sot, new_val in zip(
-                inner_mit_sot,
-                inner_mit_sot_outs,
+                inner_mit_sot, inner_mit_sot_outs, strict=True
             )
         ]
```

```diff
@@ -152,7 +155,9 @@ def get_partial_traces(traces):
             + op.outer_nitsot(outer_inputs)
         )
         partial_traces = []
-        for init_state, trace, buffer in zip(init_states, traces, buffers):
+        for init_state, trace, buffer in zip(
+            init_states, traces, buffers, strict=True
+        ):
             if init_state is not None:
                 # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
                 trace = jnp.atleast_1d(trace)
```

pytensor/link/jax/dispatch/shape.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -96,7 +96,7 @@ def shape_i(x):
 def jax_funcify_SpecifyShape(op, node, **kwargs):
     def specifyshape(x, *shape):
         assert x.ndim == len(shape)
-        for actual, expected in zip(x.shape, shape):
+        for actual, expected in zip(x.shape, shape, strict=True):
             if expected is None:
                 continue
             if actual != expected:
```

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -200,7 +200,8 @@ def jax_funcify_Tri(op, node, **kwargs):
     def tri(*args):
         # args is N, M, k
         args = [
-            x if const_x is None else const_x for x, const_x in zip(args, const_args)
+            x if const_x is None else const_x
+            for x, const_x in zip(args, const_args, strict=True)
         ]
         return jnp.tri(*args, dtype=op.dtype)
```

pytensor/link/jax/linker.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -35,12 +35,14 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         ]

         fgraph.replace_all(
-            zip(shared_rng_inputs, new_shared_rng_inputs),
+            zip(shared_rng_inputs, new_shared_rng_inputs, strict=True),
             import_missing=True,
             reason="JAXLinker.fgraph_convert",
         )

-        for old_inp, new_inp in zip(shared_rng_inputs, new_shared_rng_inputs):
+        for old_inp, new_inp in zip(
+            shared_rng_inputs, new_shared_rng_inputs, strict=True
+        ):
             new_inp_storage = [new_inp.get_value(borrow=True)]
             storage_map[new_inp] = new_inp_storage
             old_inp_storage = storage_map.pop(old_inp)
```

Commit comments: 0