Commit ba5ab92

Add a strict argument to all zips
1 parent 10f285a commit ba5ab92

100 files changed: +749 −467 lines
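For context: strict=True is the zip() keyword added in Python 3.10 (PEP 618). A plain zip() silently truncates to the shortest of its iterables; with strict=True, any length mismatch raises ValueError instead. A minimal sketch of the difference:

    pairs = list(zip([1, 2, 3], ["a", "b"]))
    # -> [(1, 'a'), (2, 'b')]: the trailing 3 is silently dropped

    try:
        list(zip([1, 2, 3], ["a", "b"], strict=True))
    except ValueError as err:
        print(err)  # zip() argument 2 is shorter than argument 1

This is the point of threading strict=True through every zip call: latent length mismatches between parallel sequences now fail loudly instead of silently corrupting results downstream.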

pytensor/compile/builders.py

Lines changed: 24 additions & 12 deletions
@@ -45,15 +45,15 @@ def infer_shape(outs, inputs, input_shapes):
     # TODO: ShapeFeature should live elsewhere
     from pytensor.tensor.rewriting.shape import ShapeFeature
 
-    for inp, inp_shp in zip(inputs, input_shapes):
+    for inp, inp_shp in zip(inputs, input_shapes, strict=True):
         if inp_shp is not None and len(inp_shp) != inp.type.ndim:
             assert len(inp_shp) == inp.type.ndim
 
     shape_feature = ShapeFeature()
     shape_feature.on_attach(FunctionGraph([], []))
 
     # Initialize shape_of with the input shapes
-    for inp, inp_shp in zip(inputs, input_shapes):
+    for inp, inp_shp in zip(inputs, input_shapes, strict=True):
         shape_feature.set_shape(inp, inp_shp)
 
     def local_traverse(out):
@@ -110,7 +110,9 @@ def construct_nominal_fgraph(
 
     replacements = dict(
         zip(
-            inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
+            inputs + implicit_shared_inputs,
+            dummy_inputs + dummy_implicit_shared_inputs,
+            strict=True,
         )
     )
 
@@ -140,7 +142,7 @@ def construct_nominal_fgraph(
         NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
     )
 
-    fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
+    fgraph.replace_all(zip(local_inputs, nominal_local_inputs, strict=True))
 
     for i, inp in enumerate(fgraph.inputs):
         nom_inp = nominal_local_inputs[i]
@@ -559,7 +561,9 @@ def lop_overrides(inps, grads):
             # compute non-overriding downsteam grads from upstreams grads
             # it's normal some input may be disconnected, thus the 'ignore'
             wrt = [
-                lin for lin, gov in zip(inner_inputs, custom_input_grads) if gov is None
+                lin
+                for lin, gov in zip(inner_inputs, custom_input_grads, strict=True)
+                if gov is None
             ]
             default_input_grads = fn_grad(wrt=wrt) if wrt else []
             input_grads = self._combine_list_overrides(
@@ -650,7 +654,7 @@ def _build_and_cache_rop_op(self):
         f = [
             output
             for output, custom_output_grad in zip(
-                inner_outputs, custom_output_grads
+                inner_outputs, custom_output_grads, strict=True
             )
             if custom_output_grad is None
         ]
@@ -730,18 +734,24 @@ def make_node(self, *inputs):
 
         non_shared_inputs = [
             inp_t.filter_variable(inp)
-            for inp, inp_t in zip(non_shared_inputs, self.input_types)
+            for inp, inp_t in zip(non_shared_inputs, self.input_types, strict=True)
         ]
 
         new_shared_inputs = inputs[num_expected_inps:]
-        inner_and_input_shareds = list(zip(self.shared_inputs, new_shared_inputs))
+        inner_and_input_shareds = list(
+            zip(self.shared_inputs, new_shared_inputs, strict=True)
+        )
 
         if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
             # The shared variables are not equal to the original shared
             # variables, so we construct a new `Op` that uses the new shared
             # variables instead.
             replace = dict(
-                zip(self.inner_inputs[num_expected_inps:], new_shared_inputs)
+                zip(
+                    self.inner_inputs[num_expected_inps:],
+                    new_shared_inputs,
+                    strict=True,
+                )
             )
 
             # If the new shared variables are inconsistent with the inner-graph,
@@ -808,7 +818,7 @@ def infer_shape(self, fgraph, node, shapes):
         # each shape call. PyTensor optimizer will clean this up later, but this
         # will make extra work for the optimizer.
 
-        repl = dict(zip(self.inner_inputs, node.inputs))
+        repl = dict(zip(self.inner_inputs, node.inputs, strict=True))
         clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
         cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)
         ret = []
@@ -850,7 +860,7 @@ def clone(self):
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
-        for output, variable in zip(outputs, variables):
+        for output, variable in zip(outputs, variables, strict=True):
             output[0] = variable
 
 
@@ -866,7 +876,9 @@ def inline_ofg_expansion(fgraph, node):
         return False
     if not op.is_inline:
         return False
-    return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
+    return clone_replace(
+        op.inner_outputs, dict(zip(op.inner_inputs, node.inputs, strict=True))
+    )
 
 
 # We want to run this before the first merge optimizer
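Several of these call sites build replacement mappings via dict(zip(...)). Without strict=True, a mismatch between the inner-graph inputs and the node inputs would not raise; it would simply leave some variables unmapped. A sketch using hypothetical stand-in lists:

    inner_inputs = ["x", "y", "z"]  # hypothetical inner-graph inputs
    node_inputs = ["a", "b"]        # hypothetical node inputs, one entry short

    repl = dict(zip(inner_inputs, node_inputs))
    # -> {'x': 'a', 'y': 'b'}: 'z' is silently left without a replacement

    try:
        repl = dict(zip(inner_inputs, node_inputs, strict=True))
    except ValueError as err:
        print(err)  # the mismatch now fails loudly instead of corrupting the graph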

pytensor/compile/debugmode.py

Lines changed: 9 additions & 7 deletions
@@ -868,7 +868,7 @@ def _get_preallocated_maps(
             # except if broadcastable, or for dimensions above
             # config.DebugMode__check_preallocated_output_ndim
             buf_shape = []
-            for s, b in zip(r_vals[r].shape, r.broadcastable):
+            for s, b in zip(r_vals[r].shape, r.broadcastable, strict=True):
                 if b or ((r.ndim - len(buf_shape)) > check_ndim):
                     buf_shape.append(s)
                 else:
@@ -946,7 +946,7 @@ def _get_preallocated_maps(
                 r_shape_diff = shape_diff[: r.ndim]
                 new_buf_shape = [
                     max((s + sd), 0)
-                    for s, sd in zip(r_vals[r].shape, r_shape_diff)
+                    for s, sd in zip(r_vals[r].shape, r_shape_diff, strict=True)
                 ]
                 new_buf = np.empty(new_buf_shape, dtype=r.type.dtype)
                 new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
@@ -1578,7 +1578,7 @@ def f():
             # try:
             # compute the value of all variables
             for i, (thunk_py, thunk_c, node) in enumerate(
-                zip(thunks_py, thunks_c, order)
+                zip(thunks_py, thunks_c, order, strict=True)
             ):
                 _logger.debug(f"{i} - starting node {i} {node}")
 
@@ -1866,7 +1866,7 @@ def thunk():
                 assert s[0] is None
 
             # store our output variables to their respective storage lists
-            for output, storage in zip(fgraph.outputs, output_storage):
+            for output, storage in zip(fgraph.outputs, output_storage, strict=True):
                 storage[0] = r_vals[output]
 
             # transfer all inputs back to their respective storage lists
@@ -1942,11 +1942,11 @@ def deco():
         f,
         [
             Container(input, storage, readonly=False)
-            for input, storage in zip(fgraph.inputs, input_storage)
+            for input, storage in zip(fgraph.inputs, input_storage, strict=True)
         ],
         [
             Container(output, storage, readonly=True)
-            for output, storage in zip(fgraph.outputs, output_storage)
+            for output, storage in zip(fgraph.outputs, output_storage, strict=True)
         ],
         thunks_py,
         order,
@@ -2133,7 +2133,9 @@ def __init__(
 
         no_borrow = [
             output
-            for output, spec in zip(fgraph.outputs, outputs + additional_outputs)
+            for output, spec in zip(
+                fgraph.outputs, outputs + additional_outputs, strict=True
+            )
             if not spec.borrow
         ]
         if no_borrow:
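The debugmode changes mostly zip a runtime value's shape against type metadata that should have the same length (broadcast flags, per-dimension shape offsets). A small sketch, with a hypothetical broadcastable pattern, of the mismatch strict=True now surfaces:

    import numpy as np

    val = np.zeros((3, 1))                # a 2-d runtime value
    broadcastable = (False, True, False)  # hypothetical flags for a 3-d type

    list(zip(val.shape, broadcastable))   # [(3, False), (1, True)]: rank bug hidden

    try:
        list(zip(val.shape, broadcastable, strict=True))
    except ValueError as err:
        print(err)  # the 2-d value no longer slips past the 3-d flags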

pytensor/compile/function/pfunc.py

Lines changed: 3 additions & 3 deletions
@@ -603,7 +603,7 @@ def construct_pfunc_ins_and_outs(
 
     new_inputs = []
 
-    for i, iv in zip(inputs, input_variables):
+    for i, iv in zip(inputs, input_variables, strict=True):
         new_i = copy(i)
         new_i.variable = iv
 
@@ -637,13 +637,13 @@ def construct_pfunc_ins_and_outs(
         assert len(fgraph.inputs) == len(inputs)
         assert len(fgraph.outputs) == len(outputs)
 
-        for fg_inp, inp in zip(fgraph.inputs, inputs):
+        for fg_inp, inp in zip(fgraph.inputs, inputs, strict=True):
             if fg_inp != getattr(inp, "variable", inp):
                 raise ValueError(
                     f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}"
                 )
 
-        for fg_out, out in zip(fgraph.outputs, outputs):
+        for fg_out, out in zip(fgraph.outputs, outputs, strict=True):
             if fg_out != getattr(out, "variable", out):
                 raise ValueError(
                     f"`fgraph`'s output does not match the provided output: {fg_out}, {out}"

pytensor/compile/function/types.py

Lines changed: 21 additions & 14 deletions
@@ -241,7 +241,7 @@ def std_fgraph(
     fgraph.attach_feature(
         Supervisor(
             input
-            for spec, input in zip(input_specs, fgraph.inputs)
+            for spec, input in zip(input_specs, fgraph.inputs, strict=True)
             if not (
                 spec.mutable
                 or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
@@ -422,7 +422,7 @@ def distribute(indices, cs, value):
         # this loop works by modifying the elements (as variable c) of
         # self.input_storage inplace.
         for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(
-            zip(self.indices, defaults)
+            zip(self.indices, defaults, strict=True)
         ):
             if indices is None:
                 # containers is being used as a stack. Here we pop off
@@ -651,7 +651,7 @@ def checkSV(sv_ori, sv_rpl):
         else:
             outs = list(map(SymbolicOutput, fg_cpy.outputs))
 
-        for out_ori, out_cpy in zip(maker.outputs, outs):
+        for out_ori, out_cpy in zip(maker.outputs, outs, strict=False):
             out_cpy.borrow = out_ori.borrow
 
         # swap SharedVariable
@@ -664,7 +664,7 @@ def checkSV(sv_ori, sv_rpl):
                 raise ValueError(f"SharedVariable: {sv.name} not found")
 
         # Swap SharedVariable in fgraph and In instances
-        for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
+        for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
             # Variables in maker.inputs are defined by user, therefore we
             # use them to make comparison and do the mapping.
             # Otherwise we don't touch them.
@@ -688,7 +688,7 @@ def checkSV(sv_ori, sv_rpl):
 
         # Delete update if needed
         rev_update_mapping = {v: k for k, v in fg_cpy.update_mapping.items()}
-        for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs)):
+        for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
             inp.variable = in_var
             if not delete_updates and inp.update is not None:
                 out_idx = rev_update_mapping[n]
@@ -748,7 +748,11 @@ def checkSV(sv_ori, sv_rpl):
         ).create(input_storage, storage_map=new_storage_map)
 
         for in_ori, in_cpy, ori, cpy in zip(
-            maker.inputs, f_cpy.maker.inputs, self.input_storage, f_cpy.input_storage
+            maker.inputs,
+            f_cpy.maker.inputs,
+            self.input_storage,
+            f_cpy.input_storage,
+            strict=True,
         ):
             # Share immutable ShareVariable and constant input's storage
             swapped = swap is not None and in_ori.variable in swap
@@ -908,6 +912,7 @@ def restore_defaults():
                             self.input_storage[k].storage[0]
                             for k in args_share_memory[j]
                         ],
+                        strict=True,
                     )
                     if any(
                         (
@@ -1000,7 +1005,7 @@ def restore_defaults():
         if getattr(self.vm, "allow_gc", False):
             assert len(self.output_storage) == len(self.maker.fgraph.outputs)
             for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs
+                self.output_storage, self.maker.fgraph.outputs, strict=True
             ):
                 if o_variable.owner is not None:
                     # this node is the variable of computation
@@ -1012,7 +1017,7 @@ def restore_defaults():
         if getattr(self.vm, "need_update_inputs", True):
             # Update the inputs that have an update function
             for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, self.input_storage))
+                list(zip(self.maker.expanded_inputs, self.input_storage, strict=True))
             ):
                 if input.update is not None:
                     storage.data = outputs.pop()
@@ -1047,7 +1052,7 @@ def restore_defaults():
             assert len(self.output_keys) == len(outputs)
 
             if output_subset is None:
-                return dict(zip(self.output_keys, outputs))
+                return dict(zip(self.output_keys, outputs, strict=True))
             else:
                 return {
                     self.output_keys[index]: outputs[index]
@@ -1115,7 +1120,7 @@ def _pickle_Function(f):
     input_storage = []
 
     for (input, indices, inputs), (required, refeed, default) in zip(
-        f.indices, f.defaults
+        f.indices, f.defaults, strict=True
    ):
         input_storage.append(ins[0])
         del ins[0]
@@ -1157,7 +1162,7 @@ def _constructor_Function(maker, input_storage, inputs_data, trust_input=False):
 
     f = maker.create(input_storage)
     assert len(f.input_storage) == len(inputs_data)
-    for container, x in zip(f.input_storage, inputs_data):
+    for container, x in zip(f.input_storage, inputs_data, strict=True):
         assert (
             (container.data is x)
             or (isinstance(x, np.ndarray) and (container.data == x).all())
@@ -1191,7 +1196,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
     reason = "insert_deepcopy"
     updated_fgraph_inputs = {
         fgraph_i
-        for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs)
+        for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True)
         if getattr(i, "update", False)
     }
 
@@ -1528,7 +1533,9 @@ def __init__(
         # return the internal storage pointer.
         no_borrow = [
             output
-            for output, spec in zip(fgraph.outputs, outputs + found_updates)
+            for output, spec in zip(
+                fgraph.outputs, outputs + found_updates, strict=True
+            )
             if not spec.borrow
         ]
 
@@ -1595,7 +1602,7 @@ def create(self, input_storage=None, storage_map=None):
         # defaults lists.
         assert len(self.indices) == len(input_storage)
         for i, ((input, indices, subinputs), input_storage_i) in enumerate(
-            zip(self.indices, input_storage)
+            zip(self.indices, input_storage, strict=True)
         ):
             # Replace any default value given as a variable by its
             # container. Note that this makes sense only in the
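Note the one deliberate exception in this file: the zip over maker.outputs and the copied outputs (the @@ -651 hunk) uses strict=False. Presumably this is because Function.copy() with delete_updates=True can drop update outputs from the copy, so the two sequences may legitimately differ in length and truncation is the intended behavior. A hypothetical sketch:

    maker_outputs = ["out", "update_out"]  # hypothetical: original outputs plus an update
    copied_outputs = ["out"]               # hypothetical copy with the update deleted

    # strict=False spells out that truncating to the shorter list is intended:
    for out_ori, out_cpy in zip(maker_outputs, copied_outputs, strict=False):
        print(out_ori, out_cpy)  # only the shared 'out' pair is visited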

pytensor/d3viz/formatting.py

Lines changed: 2 additions & 2 deletions
@@ -244,14 +244,14 @@ def format_map(m):
         ext_inputs = [self.__node_id(x) for x in node.inputs]
         int_inputs = [gf.__node_id(x) for x in node.op.inner_inputs]
         assert len(ext_inputs) == len(int_inputs)
-        h = format_map(zip(ext_inputs, int_inputs))
+        h = format_map(zip(ext_inputs, int_inputs, strict=True))
         pd_node.get_attributes()["subg_map_inputs"] = h
 
         # Outputs mapping
         ext_outputs = [self.__node_id(x) for x in node.outputs]
         int_outputs = [gf.__node_id(x) for x in node.op.inner_outputs]
         assert len(ext_outputs) == len(int_outputs)
-        h = format_map(zip(int_outputs, ext_outputs))
+        h = format_map(zip(int_outputs, ext_outputs, strict=True))
         pd_node.get_attributes()["subg_map_outputs"] = h
 
         return graph
