Skip to content

Commit 8983796

Browse files
committed
Make zips strict in pytensor/link
1 parent a65c3df commit 8983796

File tree

3 files changed

+26
-20
lines changed

3 files changed

+26
-20
lines changed

pytensor/link/basic.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -385,11 +385,11 @@ def make_all(
385385
f,
386386
[
387387
Container(input, storage)
388-
for input, storage in zip(fgraph.inputs, input_storage)
388+
for input, storage in zip(fgraph.inputs, input_storage, strict=True)
389389
],
390390
[
391391
Container(output, storage, readonly=True)
392-
for output, storage in zip(fgraph.outputs, output_storage)
392+
for output, storage in zip(fgraph.outputs, output_storage, strict=True)
393393
],
394394
thunks,
395395
order,
@@ -509,7 +509,9 @@ def make_thunk(self, **kwargs):
509509
kwargs.pop("input_storage", None)
510510
make_all += [x.make_all(**kwargs) for x in self.linkers[1:]]
511511

512-
fns, input_lists, output_lists, thunk_lists, order_lists = zip(*make_all)
512+
fns, input_lists, output_lists, thunk_lists, order_lists = zip(
513+
*make_all, strict=True
514+
)
513515

514516
order_list0 = order_lists[0]
515517
for order_list in order_lists[1:]:
@@ -521,11 +523,11 @@ def make_thunk(self, **kwargs):
521523
inputs0 = input_lists[0]
522524
outputs0 = output_lists[0]
523525

524-
thunk_groups = list(zip(*thunk_lists))
525-
order = [x[0] for x in zip(*order_lists)]
526+
thunk_groups = list(zip(*thunk_lists, strict=True))
527+
order = [x[0] for x in zip(*order_lists, strict=True)]
526528

527529
to_reset = []
528-
for thunks, node in zip(thunk_groups, order):
530+
for thunks, node in zip(thunk_groups, order, strict=True):
529531
for j, output in enumerate(node.outputs):
530532
if output in no_recycling:
531533
for thunk in thunks:
@@ -536,12 +538,12 @@ def make_thunk(self, **kwargs):
536538

537539
def f():
538540
for inputs in input_lists[1:]:
539-
for input1, input2 in zip(inputs0, inputs):
541+
for input1, input2 in zip(inputs0, inputs, strict=True):
540542
input2.storage[0] = copy(input1.storage[0])
541543
for x in to_reset:
542544
x[0] = None
543545
pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
544-
for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
546+
for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=True)):
545547
try:
546548
wrapper(self.fgraph, i, node, *thunks)
547549
except Exception:
@@ -663,7 +665,9 @@ def thunk(
663665
):
664666
outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
665667

666-
for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
668+
for o_var, o_storage, o_val in zip(
669+
fgraph.outputs, thunk_outputs, outputs, strict=True
670+
):
667671
compute_map[o_var][0] = True
668672
o_storage[0] = self.output_filter(o_var, o_val)
669673
return outputs
@@ -731,11 +735,11 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
731735
fn,
732736
[
733737
Container(input, storage)
734-
for input, storage in zip(fgraph.inputs, input_storage)
738+
for input, storage in zip(fgraph.inputs, input_storage, strict=True)
735739
],
736740
[
737741
Container(output, storage, readonly=True)
738-
for output, storage in zip(fgraph.outputs, output_storage)
742+
for output, storage in zip(fgraph.outputs, output_storage, strict=True)
739743
],
740744
thunks,
741745
nodes,

pytensor/link/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def map_storage(
8888
assert len(fgraph.inputs) == len(input_storage)
8989

9090
# add input storage into storage_map
91-
for r, storage in zip(fgraph.inputs, input_storage):
91+
for r, storage in zip(fgraph.inputs, input_storage, strict=True):
9292
if r in storage_map:
9393
assert storage_map[r] is storage, (
9494
"Given input_storage conflicts "
@@ -108,7 +108,7 @@ def map_storage(
108108
# allocate output storage
109109
if output_storage is not None:
110110
assert len(fgraph.outputs) == len(output_storage)
111-
for r, storage in zip(fgraph.outputs, output_storage):
111+
for r, storage in zip(fgraph.outputs, output_storage, strict=True):
112112
if r in storage_map:
113113
assert storage_map[r] is storage, (
114114
"Given output_storage confl"
@@ -191,7 +191,7 @@ def streamline_default_f():
191191
x[0] = None
192192
try:
193193
for thunk, node, old_storage in zip(
194-
thunks, order, post_thunk_old_storage
194+
thunks, order, post_thunk_old_storage, strict=True
195195
):
196196
thunk()
197197
for old_s in old_storage:
@@ -206,7 +206,7 @@ def streamline_nice_errors_f():
206206
for x in no_recycling:
207207
x[0] = None
208208
try:
209-
for thunk, node in zip(thunks, order):
209+
for thunk, node in zip(thunks, order, strict=True):
210210
thunk()
211211
except Exception:
212212
raise_with_op(fgraph, node, thunk)

pytensor/link/vm.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def clear_storage(self):
244244
def update_profile(self, profile):
245245
"""Update a profile object."""
246246
for node, thunk, t, c in zip(
247-
self.nodes, self.thunks, self.call_times, self.call_counts
247+
self.nodes, self.thunks, self.call_times, self.call_counts, strict=True
248248
):
249249
profile.apply_time.setdefault((self.fgraph, node), 0.0)
250250
profile.apply_time[(self.fgraph, node)] += t
@@ -312,7 +312,9 @@ def __init__(
312312
self.output_storage = output_storage
313313
self.inp_storage_and_out_idx = tuple(
314314
(inp_storage, self.fgraph.outputs.index(update_vars[inp]))
315-
for inp, inp_storage in zip(self.fgraph.inputs, self.input_storage)
315+
for inp, inp_storage in zip(
316+
self.fgraph.inputs, self.input_storage, strict=True
317+
)
316318
if inp in update_vars
317319
)
318320

@@ -1250,7 +1252,7 @@ def make_all(
12501252
self.profile.linker_node_make_thunks += t1 - t0
12511253
self.profile.linker_make_thunk_time = linker_make_thunk_time
12521254

1253-
for node, thunk in zip(order, thunks):
1255+
for node, thunk in zip(order, thunks, strict=True):
12541256
thunk.inputs = [storage_map[v] for v in node.inputs]
12551257
thunk.outputs = [storage_map[v] for v in node.outputs]
12561258

@@ -1306,11 +1308,11 @@ def make_all(
13061308
vm,
13071309
[
13081310
Container(input, storage)
1309-
for input, storage in zip(fgraph.inputs, input_storage)
1311+
for input, storage in zip(fgraph.inputs, input_storage, strict=True)
13101312
],
13111313
[
13121314
Container(output, storage, readonly=True)
1313-
for output, storage in zip(fgraph.outputs, output_storage)
1315+
for output, storage in zip(fgraph.outputs, output_storage, strict=True)
13141316
],
13151317
thunks,
13161318
order,

0 commit comments

Comments
 (0)