
Commit cf9c5cb

Make zips strict in pytensor/compile
1 parent: ed5e480

4 files changed: +57 −36 lines changed
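
For context: `strict=True` was added to the builtin `zip` in Python 3.10 (PEP 618). A plain `zip` stops at the shortest iterable, so a length mismatch between sequences that are supposed to be parallel passes silently; with `strict=True` the mismatch raises a `ValueError`. A minimal sketch of the difference:

```python
inputs = ["a", "b", "c"]
shapes = [(2,), (3,)]  # one entry short

# Default zip silently truncates: 'c' is dropped without any warning.
print(list(zip(inputs, shapes)))  # [('a', (2,)), ('b', (3,))]

# strict=True surfaces the mismatch instead.
try:
    list(zip(inputs, shapes, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```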

pytensor/compile/builders.py

Lines changed: 24 additions & 12 deletions
```diff
@@ -45,15 +45,15 @@ def infer_shape(outs, inputs, input_shapes):
     # TODO: ShapeFeature should live elsewhere
     from pytensor.tensor.rewriting.shape import ShapeFeature
 
-    for inp, inp_shp in zip(inputs, input_shapes):
+    for inp, inp_shp in zip(inputs, input_shapes, strict=True):
         if inp_shp is not None and len(inp_shp) != inp.type.ndim:
             assert len(inp_shp) == inp.type.ndim
 
     shape_feature = ShapeFeature()
     shape_feature.on_attach(FunctionGraph([], []))
 
     # Initialize shape_of with the input shapes
-    for inp, inp_shp in zip(inputs, input_shapes):
+    for inp, inp_shp in zip(inputs, input_shapes, strict=True):
         shape_feature.set_shape(inp, inp_shp)
 
     def local_traverse(out):
@@ -110,7 +110,9 @@ def construct_nominal_fgraph(
 
     replacements = dict(
         zip(
-            inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
+            inputs + implicit_shared_inputs,
+            dummy_inputs + dummy_implicit_shared_inputs,
+            strict=True,
         )
     )
 
@@ -140,7 +142,7 @@ def construct_nominal_fgraph(
         NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
     )
 
-    fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
+    fgraph.replace_all(zip(local_inputs, nominal_local_inputs, strict=True))
 
     for i, inp in enumerate(fgraph.inputs):
         nom_inp = nominal_local_inputs[i]
@@ -559,7 +561,9 @@ def lop_overrides(inps, grads):
             # compute non-overriding downsteam grads from upstreams grads
             # it's normal some input may be disconnected, thus the 'ignore'
             wrt = [
-                lin for lin, gov in zip(inner_inputs, custom_input_grads) if gov is None
+                lin
+                for lin, gov in zip(inner_inputs, custom_input_grads, strict=True)
+                if gov is None
             ]
             default_input_grads = fn_grad(wrt=wrt) if wrt else []
             input_grads = self._combine_list_overrides(
@@ -650,7 +654,7 @@ def _build_and_cache_rop_op(self):
         f = [
             output
             for output, custom_output_grad in zip(
-                inner_outputs, custom_output_grads
+                inner_outputs, custom_output_grads, strict=True
             )
             if custom_output_grad is None
         ]
@@ -730,18 +734,24 @@ def make_node(self, *inputs):
 
         non_shared_inputs = [
             inp_t.filter_variable(inp)
-            for inp, inp_t in zip(non_shared_inputs, self.input_types)
+            for inp, inp_t in zip(non_shared_inputs, self.input_types, strict=True)
         ]
 
         new_shared_inputs = inputs[num_expected_inps:]
-        inner_and_input_shareds = list(zip(self.shared_inputs, new_shared_inputs))
+        inner_and_input_shareds = list(
+            zip(self.shared_inputs, new_shared_inputs, strict=True)
+        )
 
         if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
             # The shared variables are not equal to the original shared
             # variables, so we construct a new `Op` that uses the new shared
             # variables instead.
             replace = dict(
-                zip(self.inner_inputs[num_expected_inps:], new_shared_inputs)
+                zip(
+                    self.inner_inputs[num_expected_inps:],
+                    new_shared_inputs,
+                    strict=True,
+                )
             )
 
             # If the new shared variables are inconsistent with the inner-graph,
@@ -808,7 +818,7 @@ def infer_shape(self, fgraph, node, shapes):
         # each shape call. PyTensor optimizer will clean this up later, but this
         # will make extra work for the optimizer.
 
-        repl = dict(zip(self.inner_inputs, node.inputs))
+        repl = dict(zip(self.inner_inputs, node.inputs, strict=True))
         clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
         cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)
         ret = []
@@ -850,7 +860,7 @@ def clone(self):
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
-        for output, variable in zip(outputs, variables):
+        for output, variable in zip(outputs, variables, strict=True):
             output[0] = variable
 
 
@@ -866,7 +876,9 @@ def inline_ofg_expansion(fgraph, node):
         return False
     if not op.is_inline:
         return False
-    return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
+    return clone_replace(
+        op.inner_outputs, dict(zip(op.inner_inputs, node.inputs, strict=True))
+    )
 
 
 # We want to run this before the first merge optimizer
```
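
Several of the call sites above build replacement mappings from inner-graph variables to outer node inputs via `dict(zip(...))`. With a non-strict zip, a missing or extra input yields a silently incomplete mapping rather than an error. The sketch below (hypothetical names, not PyTensor API) shows the failure mode `strict=True` guards against:

```python
# Hypothetical stand-ins for inner-graph variables and outer node inputs.
inner_inputs = ["inner_x", "inner_y", "inner_z"]
node_inputs = ["outer_x", "outer_y"]  # one outer input missing (a bug)

# Non-strict zip: "inner_z" silently gets no replacement.
repl = dict(zip(inner_inputs, node_inputs))
assert "inner_z" not in repl

# Strict zip: the mismatch is reported at the call site instead of
# surfacing later as a half-substituted graph.
try:
    repl = dict(zip(inner_inputs, node_inputs, strict=True))
except ValueError as err:
    print(f"caught: {err}")
```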

pytensor/compile/debugmode.py

Lines changed: 9 additions & 7 deletions
```diff
@@ -865,7 +865,7 @@ def _get_preallocated_maps(
         # except if broadcastable, or for dimensions above
         # config.DebugMode__check_preallocated_output_ndim
         buf_shape = []
-        for s, b in zip(r_vals[r].shape, r.broadcastable):
+        for s, b in zip(r_vals[r].shape, r.broadcastable, strict=True):
             if b or ((r.ndim - len(buf_shape)) > check_ndim):
                 buf_shape.append(s)
             else:
@@ -944,7 +944,7 @@ def _get_preallocated_maps(
             r_shape_diff = shape_diff[: r.ndim]
             new_buf_shape = [
                 max((s + sd), 0)
-                for s, sd in zip(r_vals[r].shape, r_shape_diff)
+                for s, sd in zip(r_vals[r].shape, r_shape_diff, strict=True)
             ]
             new_buf = np.empty(new_buf_shape, dtype=r.type.dtype)
             new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
@@ -1580,7 +1580,7 @@ def f():
             # try:
             # compute the value of all variables
             for i, (thunk_py, thunk_c, node) in enumerate(
-                zip(thunks_py, thunks_c, order)
+                zip(thunks_py, thunks_c, order, strict=True)
             ):
                 _logger.debug(f"{i} - starting node {i} {node}")
 
@@ -1868,7 +1868,7 @@ def thunk():
                 assert s[0] is None
 
             # store our output variables to their respective storage lists
-            for output, storage in zip(fgraph.outputs, output_storage):
+            for output, storage in zip(fgraph.outputs, output_storage, strict=True):
                 storage[0] = r_vals[output]
 
             # transfer all inputs back to their respective storage lists
@@ -1944,11 +1944,11 @@ def deco():
             f,
             [
                 Container(input, storage, readonly=False)
-                for input, storage in zip(fgraph.inputs, input_storage)
+                for input, storage in zip(fgraph.inputs, input_storage, strict=True)
             ],
             [
                 Container(output, storage, readonly=True)
-                for output, storage in zip(fgraph.outputs, output_storage)
+                for output, storage in zip(fgraph.outputs, output_storage, strict=True)
             ],
             thunks_py,
             order,
@@ -2135,7 +2135,9 @@ def __init__(
 
         no_borrow = [
             output
-            for output, spec in zip(fgraph.outputs, outputs + additional_outputs)
+            for output, spec in zip(
+                fgraph.outputs, outputs + additional_outputs, strict=True
+            )
             if not spec.borrow
         ]
         if no_borrow:
```
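
The zips changed in `_get_preallocated_maps` pair a runtime value's shape with per-dimension metadata such as `broadcastable`, which must have exactly one flag per dimension. A hypothetical NumPy sketch (not the actual DebugMode code) of what strictness catches:

```python
import numpy as np

value = np.zeros((3, 1, 5))
broadcastable = (False, True)  # hypothetical stale metadata: 2 flags for a 3-d value

# A non-strict zip would quietly ignore the third dimension; strict=True
# turns the metadata/value mismatch into an immediate ValueError.
try:
    buf_shape = [1 if b else s for s, b in zip(value.shape, broadcastable, strict=True)]
except ValueError as err:
    print(f"caught: {err}")
```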

pytensor/compile/function/pfunc.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -603,7 +603,7 @@ def construct_pfunc_ins_and_outs(
 
     new_inputs = []
 
-    for i, iv in zip(inputs, input_variables):
+    for i, iv in zip(inputs, input_variables, strict=True):
         new_i = copy(i)
         new_i.variable = iv
 
@@ -637,13 +637,13 @@ def construct_pfunc_ins_and_outs(
     assert len(fgraph.inputs) == len(inputs)
     assert len(fgraph.outputs) == len(outputs)
 
-    for fg_inp, inp in zip(fgraph.inputs, inputs):
+    for fg_inp, inp in zip(fgraph.inputs, inputs, strict=True):
         if fg_inp != getattr(inp, "variable", inp):
             raise ValueError(
                 f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}"
             )
 
-    for fg_out, out in zip(fgraph.outputs, outputs):
+    for fg_out, out in zip(fgraph.outputs, outputs, strict=True):
         if fg_out != getattr(out, "variable", out):
             raise ValueError(
                 f"`fgraph`'s output does not match the provided output: {fg_out}, {out}"
```

pytensor/compile/function/types.py

Lines changed: 21 additions & 14 deletions
```diff
@@ -246,7 +246,7 @@ def std_fgraph(
     fgraph.attach_feature(
         Supervisor(
             input
-            for spec, input in zip(input_specs, fgraph.inputs)
+            for spec, input in zip(input_specs, fgraph.inputs, strict=True)
             if not (
                 spec.mutable
                 or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
@@ -427,7 +427,7 @@ def distribute(indices, cs, value):
         # this loop works by modifying the elements (as variable c) of
         # self.input_storage inplace.
         for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(
-            zip(self.indices, defaults)
+            zip(self.indices, defaults, strict=True)
         ):
             if indices is None:
                 # containers is being used as a stack. Here we pop off
@@ -656,7 +656,7 @@ def checkSV(sv_ori, sv_rpl):
         else:
             outs = list(map(SymbolicOutput, fg_cpy.outputs))
 
-        for out_ori, out_cpy in zip(maker.outputs, outs):
+        for out_ori, out_cpy in zip(maker.outputs, outs, strict=True):
             out_cpy.borrow = out_ori.borrow
 
         # swap SharedVariable
@@ -669,7 +669,7 @@ def checkSV(sv_ori, sv_rpl):
                 raise ValueError(f"SharedVariable: {sv.name} not found")
 
         # Swap SharedVariable in fgraph and In instances
-        for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
+        for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
             # Variables in maker.inputs are defined by user, therefore we
             # use them to make comparison and do the mapping.
             # Otherwise we don't touch them.
@@ -693,7 +693,7 @@ def checkSV(sv_ori, sv_rpl):
 
         # Delete update if needed
         rev_update_mapping = {v: k for k, v in fg_cpy.update_mapping.items()}
-        for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs)):
+        for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
             inp.variable = in_var
             if not delete_updates and inp.update is not None:
                 out_idx = rev_update_mapping[n]
@@ -759,7 +759,11 @@ def checkSV(sv_ori, sv_rpl):
         ).create(input_storage, storage_map=new_storage_map)
 
         for in_ori, in_cpy, ori, cpy in zip(
-            maker.inputs, f_cpy.maker.inputs, self.input_storage, f_cpy.input_storage
+            maker.inputs,
+            f_cpy.maker.inputs,
+            self.input_storage,
+            f_cpy.input_storage,
+            strict=True,
         ):
             # Share immutable ShareVariable and constant input's storage
             swapped = swap is not None and in_ori.variable in swap
@@ -919,6 +923,7 @@ def restore_defaults():
                             self.input_storage[k].storage[0]
                             for k in args_share_memory[j]
                         ],
+                        strict=True,
                     )
                     if any(
                         (
@@ -1011,7 +1016,7 @@ def restore_defaults():
         if getattr(self.vm, "allow_gc", False):
             assert len(self.output_storage) == len(self.maker.fgraph.outputs)
             for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs
+                self.output_storage, self.maker.fgraph.outputs, strict=True
             ):
                 if o_variable.owner is not None:
                     # this node is the variable of computation
@@ -1023,7 +1028,7 @@ def restore_defaults():
         if getattr(self.vm, "need_update_inputs", True):
             # Update the inputs that have an update function
             for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, self.input_storage))
+                list(zip(self.maker.expanded_inputs, self.input_storage, strict=True))
             ):
                 if input.update is not None:
                     storage.data = outputs.pop()
@@ -1058,7 +1063,7 @@ def restore_defaults():
             assert len(self.output_keys) == len(outputs)
 
             if output_subset is None:
-                return dict(zip(self.output_keys, outputs))
+                return dict(zip(self.output_keys, outputs, strict=True))
             else:
                 return {
                     self.output_keys[index]: outputs[index]
@@ -1114,7 +1119,7 @@ def _pickle_Function(f):
     input_storage = []
 
     for (input, indices, inputs), (required, refeed, default) in zip(
-        f.indices, f.defaults
+        f.indices, f.defaults, strict=True
     ):
         input_storage.append(ins[0])
         del ins[0]
@@ -1156,7 +1161,7 @@ def _constructor_Function(maker, input_storage, inputs_data, trust_input=False):
 
     f = maker.create(input_storage)
     assert len(f.input_storage) == len(inputs_data)
-    for container, x in zip(f.input_storage, inputs_data):
+    for container, x in zip(f.input_storage, inputs_data, strict=True):
         assert (
             (container.data is x)
             or (isinstance(x, np.ndarray) and (container.data == x).all())
@@ -1190,7 +1195,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
     reason = "insert_deepcopy"
     updated_fgraph_inputs = {
         fgraph_i
-        for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs)
+        for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True)
         if getattr(i, "update", False)
     }
 
@@ -1528,7 +1533,9 @@ def __init__(
         # return the internal storage pointer.
         no_borrow = [
             output
-            for output, spec in zip(fgraph.outputs, outputs + found_updates)
+            for output, spec in zip(
+                fgraph.outputs, outputs + found_updates, strict=True
+            )
             if not spec.borrow
         ]
 
@@ -1595,7 +1602,7 @@ def create(self, input_storage=None, storage_map=None):
         # defaults lists.
         assert len(self.indices) == len(input_storage)
         for i, ((input, indices, subinputs), input_storage_i) in enumerate(
-            zip(self.indices, input_storage)
+            zip(self.indices, input_storage, strict=True)
         ):
             # Replace any default value given as a variable by its
             # container. Note that this makes sense only in the
```
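
One of the call sites above returns `dict(zip(self.output_keys, outputs))` to map output names to computed values; non-strict, a shortfall in `outputs` would drop trailing keys from the result instead of failing. A self-contained sketch with hypothetical keys:

```python
output_keys = ["mean", "std", "count"]
outputs = [1.5, 0.3]  # hypothetical: one computed value went missing

# Without strict=True this returns {"mean": 1.5, "std": 0.3} and the
# missing "count" only shows up later as a confusing KeyError downstream.
try:
    result = dict(zip(output_keys, outputs, strict=True))
except ValueError as err:
    print(f"caught: {err}")
```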
