Skip to content

Commit 78c83e0

Browse files
committed
Make zips strict in pytensor/scalar
1 parent fb9654e commit 78c83e0

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

pytensor/scalar/basic.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,7 +1147,7 @@ def perform(self, node, inputs, output_storage):
11471147
else:
11481148
variables = from_return_values(self.impl(*inputs))
11491149
assert len(variables) == len(output_storage)
1150-
for storage, variable in zip(output_storage, variables):
1150+
for storage, variable in zip(output_storage, variables, strict=True):
11511151
storage[0] = variable
11521152

11531153
def impl(self, *inputs):
@@ -4111,7 +4111,9 @@ def c_support_code(self, **kwargs):
41114111

41124112
def c_support_code_apply(self, node, name):
41134113
rval = []
4114-
for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
4114+
for subnode, subnodename in zip(
4115+
self.fgraph.toposort(), self.nodenames, strict=True
4116+
):
41154117
subnode_support_code = subnode.op.c_support_code_apply(
41164118
subnode, subnodename % dict(nodename=name)
41174119
)
@@ -4217,7 +4219,7 @@ def __init__(self, inputs, outputs, name="Composite"):
42174219
res2 = pytensor.compile.rebuild_collect_shared(
42184220
inputs=outputs[0].owner.op.inputs,
42194221
outputs=outputs[0].owner.op.outputs,
4220-
replace=dict(zip(outputs[0].owner.op.inputs, res[1])),
4222+
replace=dict(zip(outputs[0].owner.op.inputs, res[1], strict=True)),
42214223
)
42224224
assert len(res2[1]) == len(outputs)
42234225
assert len(res[0]) == len(inputs)
@@ -4303,7 +4305,7 @@ def make_node(self, *inputs):
43034305
assert len(inputs) == self.nin
43044306
res = pytensor.compile.rebuild_collect_shared(
43054307
self.outputs,
4306-
replace=dict(zip(self.inputs, inputs)),
4308+
replace=dict(zip(self.inputs, inputs, strict=True)),
43074309
rebuild_strict=False,
43084310
)
43094311
# After rebuild_collect_shared, the Variable in inputs
@@ -4316,7 +4318,7 @@ def make_node(self, *inputs):
43164318

43174319
def perform(self, node, inputs, output_storage):
43184320
outputs = self.py_perform_fn(*inputs)
4319-
for storage, out_val in zip(output_storage, outputs):
4321+
for storage, out_val in zip(output_storage, outputs, strict=True):
43204322
storage[0] = out_val
43214323

43224324
def grad(self, inputs, output_grads):
@@ -4386,8 +4388,8 @@ def c_code_template(self):
43864388
def c_code(self, node, nodename, inames, onames, sub):
43874389
d = dict(
43884390
chain(
4389-
zip((f"i{int(i)}" for i in range(len(inames))), inames),
4390-
zip((f"o{int(i)}" for i in range(len(onames))), onames),
4391+
zip((f"i{int(i)}" for i in range(len(inames))), inames, strict=True),
4392+
zip((f"o{int(i)}" for i in range(len(onames))), onames, strict=True),
43914393
),
43924394
**sub,
43934395
)
@@ -4435,7 +4437,7 @@ def apply(self, fgraph):
44354437
)
44364438
# make sure we don't produce any float16.
44374439
assert not any(o.dtype == "float16" for o in new_node.outputs)
4438-
for o, no in zip(node.outputs, new_node.outputs):
4440+
for o, no in zip(node.outputs, new_node.outputs, strict=True):
44394441
mapping[o] = no
44404442

44414443
new_ins = [mapping[inp] for inp in fgraph.inputs]
@@ -4479,7 +4481,7 @@ def handle_composite(node, mapping):
44794481
new_op = node.op.clone_float32()
44804482
new_outs = new_op(*[mapping[i] for i in node.inputs], return_list=True)
44814483
assert len(new_outs) == len(node.outputs)
4482-
for o, no in zip(node.outputs, new_outs):
4484+
for o, no in zip(node.outputs, new_outs, strict=True):
44834485
mapping[o] = no
44844486

44854487

pytensor/scalar/loop.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _validate_updates(
9393
)
9494
else:
9595
update = outputs
96-
for i, u in zip(init, update):
96+
for i, u in zip(init, update, strict=True):
9797
if i.type != u.type:
9898
raise TypeError(
9999
"Init and update types must be the same: "
@@ -166,7 +166,7 @@ def make_node(self, n_steps, *inputs):
166166
# Make a new op with the right input types.
167167
res = rebuild_collect_shared(
168168
self.outputs,
169-
replace=dict(zip(self.inputs, inputs)),
169+
replace=dict(zip(self.inputs, inputs, strict=True)),
170170
rebuild_strict=False,
171171
)
172172
if self.is_while:
@@ -207,7 +207,7 @@ def perform(self, node, inputs, output_storage):
207207
for i in range(n_steps):
208208
carry = inner_fn(*carry, *constant)
209209

210-
for storage, out_val in zip(output_storage, carry):
210+
for storage, out_val in zip(output_storage, carry, strict=True):
211211
storage[0] = out_val
212212

213213
@property
@@ -295,7 +295,7 @@ def c_code_template(self):
295295

296296
# Set the carry variables to the output variables
297297
_c_code += "\n"
298-
for init, update in zip(carry_subd.values(), update_subd.values()):
298+
for init, update in zip(carry_subd.values(), update_subd.values(), strict=True):
299299
_c_code += f"{init} = {update};\n"
300300

301301
# _c_code += 'printf("%%ld\\n", i);\n'
@@ -321,8 +321,8 @@ def c_code_template(self):
321321
def c_code(self, node, nodename, inames, onames, sub):
322322
d = dict(
323323
chain(
324-
zip((f"i{int(i)}" for i in range(len(inames))), inames),
325-
zip((f"o{int(i)}" for i in range(len(onames))), onames),
324+
zip((f"i{int(i)}" for i in range(len(inames))), inames, strict=True),
325+
zip((f"o{int(i)}" for i in range(len(onames))), onames, strict=True),
326326
),
327327
**sub,
328328
)

0 commit comments

Comments
 (0)