Commit 3f254f7

Add a strict argument to all zips
1 parent 05d376f


100 files changed: +750 −466 lines (large commit; only part of the diff is shown below)
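
The change is mechanical but meaningful: zip(..., strict=True) (PEP 618, Python 3.10+) raises a ValueError when its iterables have unequal lengths, whereas plain zip() silently stops at the shortest one. A minimal sketch of the difference (the variable names here are illustrative, not taken from the diff):

inputs = ["x", "y", "z"]
shapes = [(2,), (3,)]  # one entry short

list(zip(inputs, shapes))  # [('x', (2,)), ('y', (3,))] -- 'z' silently dropped
try:
    list(zip(inputs, shapes, strict=True))
except ValueError as e:
    print(e)  # zip() argument 2 is shorter than argument 1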

pytensor/compile/builders.py

Lines changed: 24 additions & 12 deletions
@@ -45,15 +45,15 @@ def infer_shape(outs, inputs, input_shapes):
     # TODO: ShapeFeature should live elsewhere
     from pytensor.tensor.rewriting.shape import ShapeFeature
 
-    for inp, inp_shp in zip(inputs, input_shapes):
+    for inp, inp_shp in zip(inputs, input_shapes, strict=True):
         if inp_shp is not None and len(inp_shp) != inp.type.ndim:
             assert len(inp_shp) == inp.type.ndim
 
     shape_feature = ShapeFeature()
     shape_feature.on_attach(FunctionGraph([], []))
 
     # Initialize shape_of with the input shapes
-    for inp, inp_shp in zip(inputs, input_shapes):
+    for inp, inp_shp in zip(inputs, input_shapes, strict=True):
         shape_feature.set_shape(inp, inp_shp)
 
     def local_traverse(out):
@@ -110,7 +110,9 @@ def construct_nominal_fgraph(
 
     replacements = dict(
         zip(
-            inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
+            inputs + implicit_shared_inputs,
+            dummy_inputs + dummy_implicit_shared_inputs,
+            strict=True,
         )
     )
 
@@ -140,7 +142,7 @@ def construct_nominal_fgraph(
         NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
     )
 
-    fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
+    fgraph.replace_all(zip(local_inputs, nominal_local_inputs, strict=True))
 
     for i, inp in enumerate(fgraph.inputs):
         nom_inp = nominal_local_inputs[i]
@@ -559,7 +561,9 @@ def lop_overrides(inps, grads):
             # compute non-overriding downsteam grads from upstreams grads
             # it's normal some input may be disconnected, thus the 'ignore'
             wrt = [
-                lin for lin, gov in zip(inner_inputs, custom_input_grads) if gov is None
+                lin
+                for lin, gov in zip(inner_inputs, custom_input_grads, strict=True)
+                if gov is None
             ]
             default_input_grads = fn_grad(wrt=wrt) if wrt else []
             input_grads = self._combine_list_overrides(
@@ -650,7 +654,7 @@ def _build_and_cache_rop_op(self):
         f = [
             output
             for output, custom_output_grad in zip(
-                inner_outputs, custom_output_grads
+                inner_outputs, custom_output_grads, strict=True
             )
             if custom_output_grad is None
         ]
@@ -730,18 +734,24 @@ def make_node(self, *inputs):
 
         non_shared_inputs = [
             inp_t.filter_variable(inp)
-            for inp, inp_t in zip(non_shared_inputs, self.input_types)
+            for inp, inp_t in zip(non_shared_inputs, self.input_types, strict=True)
         ]
 
         new_shared_inputs = inputs[num_expected_inps:]
-        inner_and_input_shareds = list(zip(self.shared_inputs, new_shared_inputs))
+        inner_and_input_shareds = list(
+            zip(self.shared_inputs, new_shared_inputs, strict=True)
+        )
 
         if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
             # The shared variables are not equal to the original shared
             # variables, so we construct a new `Op` that uses the new shared
             # variables instead.
             replace = dict(
-                zip(self.inner_inputs[num_expected_inps:], new_shared_inputs)
+                zip(
+                    self.inner_inputs[num_expected_inps:],
+                    new_shared_inputs,
+                    strict=True,
+                )
             )
 
             # If the new shared variables are inconsistent with the inner-graph,
@@ -808,7 +818,7 @@ def infer_shape(self, fgraph, node, shapes):
         # each shape call. PyTensor optimizer will clean this up later, but this
         # will make extra work for the optimizer.
 
-        repl = dict(zip(self.inner_inputs, node.inputs))
+        repl = dict(zip(self.inner_inputs, node.inputs, strict=True))
         clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
         cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)
         ret = []
@@ -850,7 +860,7 @@ def clone(self):
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
-        for output, variable in zip(outputs, variables):
+        for output, variable in zip(outputs, variables, strict=True):
             output[0] = variable
 
 
@@ -866,7 +876,9 @@ def inline_ofg_expansion(fgraph, node):
         return False
     if not op.is_inline:
         return False
-    return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
+    return clone_replace(
+        op.inner_outputs, dict(zip(op.inner_inputs, node.inputs, strict=True))
+    )
 
 
 # We want to run this before the first merge optimizer
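
Several hunks above wrap zip in dict() or a comprehension to build replacement maps (e.g. repl = dict(zip(self.inner_inputs, node.inputs, strict=True))). The strictness guarantee composes unchanged through those wrappers; a sketch with hypothetical stand-in lists:

inner_inputs = ["a", "b"]
node_inputs = [1, 2, 3]  # one extra input

dict(zip(inner_inputs, node_inputs))               # {'a': 1, 'b': 2} -- third input silently dropped
dict(zip(inner_inputs, node_inputs, strict=True))  # raises ValueError instead of building a partial map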

pytensor/compile/debugmode.py

Lines changed: 9 additions & 7 deletions
@@ -865,7 +865,7 @@ def _get_preallocated_maps(
             # except if broadcastable, or for dimensions above
             # config.DebugMode__check_preallocated_output_ndim
             buf_shape = []
-            for s, b in zip(r_vals[r].shape, r.broadcastable):
+            for s, b in zip(r_vals[r].shape, r.broadcastable, strict=True):
                 if b or ((r.ndim - len(buf_shape)) > check_ndim):
                     buf_shape.append(s)
                 else:
@@ -943,7 +943,7 @@ def _get_preallocated_maps(
                 r_shape_diff = shape_diff[: r.ndim]
                 new_buf_shape = [
                     max((s + sd), 0)
-                    for s, sd in zip(r_vals[r].shape, r_shape_diff)
+                    for s, sd in zip(r_vals[r].shape, r_shape_diff, strict=True)
                 ]
                 new_buf = np.empty(new_buf_shape, dtype=r.type.dtype)
                 new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
@@ -1579,7 +1579,7 @@ def f():
         # try:
         # compute the value of all variables
         for i, (thunk_py, thunk_c, node) in enumerate(
-            zip(thunks_py, thunks_c, order)
+            zip(thunks_py, thunks_c, order, strict=True)
         ):
             _logger.debug(f"{i} - starting node {i} {node}")
 
@@ -1867,7 +1867,7 @@ def thunk():
                 assert s[0] is None
 
             # store our output variables to their respective storage lists
-            for output, storage in zip(fgraph.outputs, output_storage):
+            for output, storage in zip(fgraph.outputs, output_storage, strict=True):
                 storage[0] = r_vals[output]
 
             # transfer all inputs back to their respective storage lists
@@ -1943,11 +1943,11 @@ def deco():
             f,
             [
                 Container(input, storage, readonly=False)
-                for input, storage in zip(fgraph.inputs, input_storage)
+                for input, storage in zip(fgraph.inputs, input_storage, strict=True)
             ],
             [
                 Container(output, storage, readonly=True)
-                for output, storage in zip(fgraph.outputs, output_storage)
+                for output, storage in zip(fgraph.outputs, output_storage, strict=True)
             ],
             thunks_py,
             order,
@@ -2134,7 +2134,9 @@ def __init__(
 
         no_borrow = [
            output
-            for output, spec in zip(fgraph.outputs, outputs + additional_outputs)
+            for output, spec in zip(
+                fgraph.outputs, outputs + additional_outputs, strict=True
+            )
             if not spec.borrow
         ]
         if no_borrow:
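
The debugmode.py loops pair a runtime value's shape with per-dimension type metadata such as r.broadcastable; with strict=True a rank mismatch between value and type now fails at the zip itself instead of yielding a silently truncated buffer shape. A sketch with made-up values:

shape = (3, 1, 5)              # runtime value's shape
broadcastable = (False, True)  # stale metadata, one flag short

for s, b in zip(shape, broadcastable, strict=True):  # raises ValueError at the mismatch
    ...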

pytensor/compile/function/pfunc.py

Lines changed: 3 additions & 3 deletions
@@ -603,7 +603,7 @@ def construct_pfunc_ins_and_outs(
 
     new_inputs = []
 
-    for i, iv in zip(inputs, input_variables):
+    for i, iv in zip(inputs, input_variables, strict=True):
         new_i = copy(i)
         new_i.variable = iv
 
@@ -637,13 +637,13 @@ def construct_pfunc_ins_and_outs(
         assert len(fgraph.inputs) == len(inputs)
         assert len(fgraph.outputs) == len(outputs)
 
-        for fg_inp, inp in zip(fgraph.inputs, inputs):
+        for fg_inp, inp in zip(fgraph.inputs, inputs, strict=True):
             if fg_inp != getattr(inp, "variable", inp):
                 raise ValueError(
                     f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}"
                 )
 
-        for fg_out, out in zip(fgraph.outputs, outputs):
+        for fg_out, out in zip(fgraph.outputs, outputs, strict=True):
             if fg_out != getattr(out, "variable", out):
                 raise ValueError(
                     f"`fgraph`'s output does not match the provided output: {fg_out}, {out}"

pytensor/compile/function/types.py

Lines changed: 21 additions & 14 deletions
@@ -243,7 +243,7 @@ def std_fgraph(
     fgraph.attach_feature(
         Supervisor(
             input
-            for spec, input in zip(input_specs, fgraph.inputs)
+            for spec, input in zip(input_specs, fgraph.inputs, strict=True)
             if not (
                 spec.mutable
                 or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
@@ -424,7 +424,7 @@ def distribute(indices, cs, value):
         # this loop works by modifying the elements (as variable c) of
         # self.input_storage inplace.
         for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(
-            zip(self.indices, defaults)
+            zip(self.indices, defaults, strict=True)
         ):
             if indices is None:
                 # containers is being used as a stack. Here we pop off
@@ -653,7 +653,7 @@ def checkSV(sv_ori, sv_rpl):
         else:
             outs = list(map(SymbolicOutput, fg_cpy.outputs))
 
-        for out_ori, out_cpy in zip(maker.outputs, outs):
+        for out_ori, out_cpy in zip(maker.outputs, outs, strict=False):
             out_cpy.borrow = out_ori.borrow
 
         # swap SharedVariable
@@ -666,7 +666,7 @@ def checkSV(sv_ori, sv_rpl):
                 raise ValueError(f"SharedVariable: {sv.name} not found")
 
         # Swap SharedVariable in fgraph and In instances
-        for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
+        for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
             # Variables in maker.inputs are defined by user, therefore we
             # use them to make comparison and do the mapping.
             # Otherwise we don't touch them.
@@ -690,7 +690,7 @@ def checkSV(sv_ori, sv_rpl):
 
         # Delete update if needed
         rev_update_mapping = {v: k for k, v in fg_cpy.update_mapping.items()}
-        for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs)):
+        for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
             inp.variable = in_var
             if not delete_updates and inp.update is not None:
                 out_idx = rev_update_mapping[n]
@@ -750,7 +750,11 @@ def checkSV(sv_ori, sv_rpl):
         ).create(input_storage, storage_map=new_storage_map)
 
         for in_ori, in_cpy, ori, cpy in zip(
-            maker.inputs, f_cpy.maker.inputs, self.input_storage, f_cpy.input_storage
+            maker.inputs,
+            f_cpy.maker.inputs,
+            self.input_storage,
+            f_cpy.input_storage,
+            strict=True,
         ):
             # Share immutable ShareVariable and constant input's storage
             swapped = swap is not None and in_ori.variable in swap
@@ -910,6 +914,7 @@ def restore_defaults():
                         self.input_storage[k].storage[0]
                         for k in args_share_memory[j]
                     ],
+                    strict=True,
                 )
                 if any(
                     (
@@ -1002,7 +1007,7 @@ def restore_defaults():
         if getattr(self.vm, "allow_gc", False):
             assert len(self.output_storage) == len(self.maker.fgraph.outputs)
             for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs
+                self.output_storage, self.maker.fgraph.outputs, strict=True
             ):
                 if o_variable.owner is not None:
                     # this node is the variable of computation
@@ -1014,7 +1019,7 @@ def restore_defaults():
         if getattr(self.vm, "need_update_inputs", True):
             # Update the inputs that have an update function
             for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, self.input_storage))
+                list(zip(self.maker.expanded_inputs, self.input_storage, strict=True))
             ):
                 if input.update is not None:
                     storage.data = outputs.pop()
@@ -1049,7 +1054,7 @@ def restore_defaults():
         assert len(self.output_keys) == len(outputs)
 
         if output_subset is None:
-            return dict(zip(self.output_keys, outputs))
+            return dict(zip(self.output_keys, outputs, strict=True))
         else:
             return {
                 self.output_keys[index]: outputs[index]
@@ -1117,7 +1122,7 @@ def _pickle_Function(f):
     input_storage = []
 
     for (input, indices, inputs), (required, refeed, default) in zip(
-        f.indices, f.defaults
+        f.indices, f.defaults, strict=True
    ):
         input_storage.append(ins[0])
         del ins[0]
@@ -1159,7 +1164,7 @@ def _constructor_Function(maker, input_storage, inputs_data, trust_input=False):
 
     f = maker.create(input_storage)
     assert len(f.input_storage) == len(inputs_data)
-    for container, x in zip(f.input_storage, inputs_data):
+    for container, x in zip(f.input_storage, inputs_data, strict=True):
         assert (
             (container.data is x)
             or (isinstance(x, np.ndarray) and (container.data == x).all())
@@ -1193,7 +1198,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
     reason = "insert_deepcopy"
     updated_fgraph_inputs = {
         fgraph_i
-        for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs)
+        for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True)
         if getattr(i, "update", False)
     }
 
@@ -1531,7 +1536,9 @@ def __init__(
         # return the internal storage pointer.
         no_borrow = [
             output
-            for output, spec in zip(fgraph.outputs, outputs + found_updates)
+            for output, spec in zip(
+                fgraph.outputs, outputs + found_updates, strict=True
+            )
             if not spec.borrow
         ]
 
@@ -1598,7 +1605,7 @@ def create(self, input_storage=None, storage_map=None):
         # defaults lists.
         assert len(self.indices) == len(input_storage)
         for i, ((input, indices, subinputs), input_storage_i) in enumerate(
-            zip(self.indices, input_storage)
+            zip(self.indices, input_storage, strict=True)
         ):
             # Replace any default value given as a variable by its
             # container. Note that this makes sense only in the
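
One zip in this file is deliberately left non-strict: the function-copy logic iterates zip(maker.outputs, outs, strict=False). The explicit strict=False preserves the old truncating behavior, presumably because the copied output list can legitimately be shorter than the original (for example when updates are dropped during the copy); that rationale is an inference, not stated in the diff. Behaviorally, strict=False is just the historical default spelled out:

originals = ["out0", "out1", "out2"]
copies = ["out0_cpy", "out1_cpy"]

list(zip(originals, copies, strict=False))  # [('out0', 'out0_cpy'), ('out1', 'out1_cpy')]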

pytensor/d3viz/formatting.py

Lines changed: 2 additions & 2 deletions
@@ -244,14 +244,14 @@ def format_map(m):
         ext_inputs = [self.__node_id(x) for x in node.inputs]
         int_inputs = [gf.__node_id(x) for x in node.op.inner_inputs]
         assert len(ext_inputs) == len(int_inputs)
-        h = format_map(zip(ext_inputs, int_inputs))
+        h = format_map(zip(ext_inputs, int_inputs, strict=True))
         pd_node.get_attributes()["subg_map_inputs"] = h
 
         # Outputs mapping
         ext_outputs = [self.__node_id(x) for x in node.outputs]
         int_outputs = [gf.__node_id(x) for x in node.op.inner_outputs]
         assert len(ext_outputs) == len(int_outputs)
-        h = format_map(zip(int_outputs, ext_outputs))
+        h = format_map(zip(int_outputs, ext_outputs, strict=True))
         pd_node.get_attributes()["subg_map_outputs"] = h
 
         return graph
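
For code that must also run on Python < 3.10, where zip() has no strict parameter, the same guarantee needs a small shim. This is a generic backport sketch, not something this commit adds (the commit uses the keyword directly):

from itertools import zip_longest

_SENTINEL = object()

def strict_zip(*iterables):
    # Mirror zip(..., strict=True): raise if the iterables have unequal lengths.
    for combo in zip_longest(*iterables, fillvalue=_SENTINEL):
        if any(item is _SENTINEL for item in combo):
            raise ValueError("strict_zip() arguments have unequal lengths")
        yield combo

# list(strict_zip([1, 2], ["a", "b"]))     -> [(1, 'a'), (2, 'b')]
# list(strict_zip([1, 2, 3], ["a", "b"]))  -> ValueError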
