Skip to content

Commit 2c7d3ed

Browse files
committed
Make zips strict in pytensor/scalar
1 parent 34d138c commit 2c7d3ed

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

pytensor/scalar/basic.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,7 +1147,7 @@ def perform(self, node, inputs, output_storage):
11471147
else:
11481148
variables = from_return_values(self.impl(*inputs))
11491149
assert len(variables) == len(output_storage)
1150-
for storage, variable in zip(output_storage, variables):
1150+
for storage, variable in zip(output_storage, variables, strict=True):
11511151
storage[0] = variable
11521152

11531153
def impl(self, *inputs):
@@ -4109,7 +4109,9 @@ def c_support_code(self, **kwargs):
41094109

41104110
def c_support_code_apply(self, node, name):
41114111
rval = []
4112-
for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
4112+
for subnode, subnodename in zip(
4113+
self.fgraph.toposort(), self.nodenames, strict=True
4114+
):
41134115
subnode_support_code = subnode.op.c_support_code_apply(
41144116
subnode, subnodename % dict(nodename=name)
41154117
)
@@ -4215,7 +4217,7 @@ def __init__(self, inputs, outputs, name="Composite"):
42154217
res2 = pytensor.compile.rebuild_collect_shared(
42164218
inputs=outputs[0].owner.op.inputs,
42174219
outputs=outputs[0].owner.op.outputs,
4218-
replace=dict(zip(outputs[0].owner.op.inputs, res[1])),
4220+
replace=dict(zip(outputs[0].owner.op.inputs, res[1], strict=True)),
42194221
)
42204222
assert len(res2[1]) == len(outputs)
42214223
assert len(res[0]) == len(inputs)
@@ -4301,7 +4303,7 @@ def make_node(self, *inputs):
43014303
assert len(inputs) == self.nin
43024304
res = pytensor.compile.rebuild_collect_shared(
43034305
self.outputs,
4304-
replace=dict(zip(self.inputs, inputs)),
4306+
replace=dict(zip(self.inputs, inputs, strict=True)),
43054307
rebuild_strict=False,
43064308
)
43074309
# After rebuild_collect_shared, the Variable in inputs
@@ -4314,7 +4316,7 @@ def make_node(self, *inputs):
43144316

43154317
def perform(self, node, inputs, output_storage):
43164318
outputs = self.py_perform_fn(*inputs)
4317-
for storage, out_val in zip(output_storage, outputs):
4319+
for storage, out_val in zip(output_storage, outputs, strict=True):
43184320
storage[0] = out_val
43194321

43204322
def grad(self, inputs, output_grads):
@@ -4384,8 +4386,8 @@ def c_code_template(self):
43844386
def c_code(self, node, nodename, inames, onames, sub):
43854387
d = dict(
43864388
chain(
4387-
zip((f"i{int(i)}" for i in range(len(inames))), inames),
4388-
zip((f"o{int(i)}" for i in range(len(onames))), onames),
4389+
zip((f"i{int(i)}" for i in range(len(inames))), inames, strict=True),
4390+
zip((f"o{int(i)}" for i in range(len(onames))), onames, strict=True),
43894391
),
43904392
**sub,
43914393
)
@@ -4433,7 +4435,7 @@ def apply(self, fgraph):
44334435
)
44344436
# make sure we don't produce any float16.
44354437
assert not any(o.dtype == "float16" for o in new_node.outputs)
4436-
for o, no in zip(node.outputs, new_node.outputs):
4438+
for o, no in zip(node.outputs, new_node.outputs, strict=True):
44374439
mapping[o] = no
44384440

44394441
new_ins = [mapping[inp] for inp in fgraph.inputs]
@@ -4477,7 +4479,7 @@ def handle_composite(node, mapping):
44774479
new_op = node.op.clone_float32()
44784480
new_outs = new_op(*[mapping[i] for i in node.inputs], return_list=True)
44794481
assert len(new_outs) == len(node.outputs)
4480-
for o, no in zip(node.outputs, new_outs):
4482+
for o, no in zip(node.outputs, new_outs, strict=True):
44814483
mapping[o] = no
44824484

44834485

pytensor/scalar/loop.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _validate_updates(
9393
)
9494
else:
9595
update = outputs
96-
for i, u in zip(init, update):
96+
for i, u in zip(init, update, strict=False):
9797
if i.type != u.type:
9898
raise TypeError(
9999
"Init and update types must be the same: "
@@ -166,7 +166,7 @@ def make_node(self, n_steps, *inputs):
166166
# Make a new op with the right input types.
167167
res = rebuild_collect_shared(
168168
self.outputs,
169-
replace=dict(zip(self.inputs, inputs)),
169+
replace=dict(zip(self.inputs, inputs, strict=True)),
170170
rebuild_strict=False,
171171
)
172172
if self.is_while:
@@ -207,7 +207,7 @@ def perform(self, node, inputs, output_storage):
207207
for i in range(n_steps):
208208
carry = inner_fn(*carry, *constant)
209209

210-
for storage, out_val in zip(output_storage, carry):
210+
for storage, out_val in zip(output_storage, carry, strict=True):
211211
storage[0] = out_val
212212

213213
@property
@@ -295,7 +295,7 @@ def c_code_template(self):
295295

296296
# Set the carry variables to the output variables
297297
_c_code += "\n"
298-
for init, update in zip(carry_subd.values(), update_subd.values()):
298+
for init, update in zip(carry_subd.values(), update_subd.values(), strict=True):
299299
_c_code += f"{init} = {update};\n"
300300

301301
# _c_code += 'printf("%%ld\\n", i);\n'
@@ -321,8 +321,8 @@ def c_code_template(self):
321321
def c_code(self, node, nodename, inames, onames, sub):
322322
d = dict(
323323
chain(
324-
zip((f"i{int(i)}" for i in range(len(inames))), inames),
325-
zip((f"o{int(i)}" for i in range(len(onames))), onames),
324+
zip((f"i{int(i)}" for i in range(len(inames))), inames, strict=True),
325+
zip((f"o{int(i)}" for i in range(len(onames))), onames, strict=True),
326326
),
327327
**sub,
328328
)

0 commit comments

Comments
 (0)