Skip to content

Commit 83828bf

Browse files
committed
Make zips strict in pytensor/link
1 parent a850066 commit 83828bf

File tree

3 files changed

+26
-20
lines changed

3 files changed

+26
-20
lines changed

pytensor/link/basic.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -385,11 +385,11 @@ def make_all(
385385
f,
386386
[
387387
Container(input, storage)
388-
for input, storage in zip(fgraph.inputs, input_storage)
388+
for input, storage in zip(fgraph.inputs, input_storage, strict=True)
389389
],
390390
[
391391
Container(output, storage, readonly=True)
392-
for output, storage in zip(fgraph.outputs, output_storage)
392+
for output, storage in zip(fgraph.outputs, output_storage, strict=True)
393393
],
394394
thunks,
395395
order,
@@ -509,7 +509,9 @@ def make_thunk(self, **kwargs):
509509
kwargs.pop("input_storage", None)
510510
make_all += [x.make_all(**kwargs) for x in self.linkers[1:]]
511511

512-
fns, input_lists, output_lists, thunk_lists, order_lists = zip(*make_all)
512+
fns, input_lists, output_lists, thunk_lists, order_lists = zip(
513+
*make_all, strict=True
514+
)
513515

514516
order_list0 = order_lists[0]
515517
for order_list in order_lists[1:]:
@@ -521,12 +523,12 @@ def make_thunk(self, **kwargs):
521523
inputs0 = input_lists[0]
522524
outputs0 = output_lists[0]
523525

524-
thunk_groups = list(zip(*thunk_lists))
525-
order = [x[0] for x in zip(*order_lists)]
526+
thunk_groups = list(zip(*thunk_lists, strict=True))
527+
order = [x[0] for x in zip(*order_lists, strict=True)]
526528

527529
to_reset = [
528530
thunk.outputs[j]
529-
for thunks, node in zip(thunk_groups, order)
531+
for thunks, node in zip(thunk_groups, order, strict=True)
530532
for j, output in enumerate(node.outputs)
531533
if output in no_recycling
532534
for thunk in thunks
@@ -537,12 +539,12 @@ def make_thunk(self, **kwargs):
537539

538540
def f():
539541
for inputs in input_lists[1:]:
540-
for input1, input2 in zip(inputs0, inputs):
542+
for input1, input2 in zip(inputs0, inputs, strict=True):
541543
input2.storage[0] = copy(input1.storage[0])
542544
for x in to_reset:
543545
x[0] = None
544546
pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
545-
for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
547+
for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=True)):
546548
try:
547549
wrapper(self.fgraph, i, node, *thunks)
548550
except Exception:
@@ -664,7 +666,9 @@ def thunk(
664666
):
665667
outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
666668

667-
for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
669+
for o_var, o_storage, o_val in zip(
670+
fgraph.outputs, thunk_outputs, outputs, strict=True
671+
):
668672
compute_map[o_var][0] = True
669673
o_storage[0] = self.output_filter(o_var, o_val)
670674
return outputs
@@ -730,11 +734,11 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
730734
fn,
731735
[
732736
Container(input, storage)
733-
for input, storage in zip(fgraph.inputs, input_storage)
737+
for input, storage in zip(fgraph.inputs, input_storage, strict=True)
734738
],
735739
[
736740
Container(output, storage, readonly=True)
737-
for output, storage in zip(fgraph.outputs, output_storage)
741+
for output, storage in zip(fgraph.outputs, output_storage, strict=True)
738742
],
739743
thunks,
740744
nodes,

pytensor/link/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def map_storage(
8888
assert len(fgraph.inputs) == len(input_storage)
8989

9090
# add input storage into storage_map
91-
for r, storage in zip(fgraph.inputs, input_storage):
91+
for r, storage in zip(fgraph.inputs, input_storage, strict=True):
9292
if r in storage_map:
9393
assert storage_map[r] is storage, (
9494
"Given input_storage conflicts "
@@ -108,7 +108,7 @@ def map_storage(
108108
# allocate output storage
109109
if output_storage is not None:
110110
assert len(fgraph.outputs) == len(output_storage)
111-
for r, storage in zip(fgraph.outputs, output_storage):
111+
for r, storage in zip(fgraph.outputs, output_storage, strict=True):
112112
if r in storage_map:
113113
assert storage_map[r] is storage, (
114114
"Given output_storage confl"
@@ -191,7 +191,7 @@ def streamline_default_f():
191191
x[0] = None
192192
try:
193193
for thunk, node, old_storage in zip(
194-
thunks, order, post_thunk_old_storage
194+
thunks, order, post_thunk_old_storage, strict=True
195195
):
196196
thunk()
197197
for old_s in old_storage:
@@ -206,7 +206,7 @@ def streamline_nice_errors_f():
206206
for x in no_recycling:
207207
x[0] = None
208208
try:
209-
for thunk, node in zip(thunks, order):
209+
for thunk, node in zip(thunks, order, strict=True):
210210
thunk()
211211
except Exception:
212212
raise_with_op(fgraph, node, thunk)

pytensor/link/vm.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def clear_storage(self):
244244
def update_profile(self, profile):
245245
"""Update a profile object."""
246246
for node, thunk, t, c in zip(
247-
self.nodes, self.thunks, self.call_times, self.call_counts
247+
self.nodes, self.thunks, self.call_times, self.call_counts, strict=True
248248
):
249249
profile.apply_time[(self.fgraph, node)] += t
250250

@@ -310,7 +310,9 @@ def __init__(
310310
self.output_storage = output_storage
311311
self.inp_storage_and_out_idx = tuple(
312312
(inp_storage, self.fgraph.outputs.index(update_vars[inp]))
313-
for inp, inp_storage in zip(self.fgraph.inputs, self.input_storage)
313+
for inp, inp_storage in zip(
314+
self.fgraph.inputs, self.input_storage, strict=True
315+
)
314316
if inp in update_vars
315317
)
316318

@@ -1247,7 +1249,7 @@ def make_all(
12471249
self.profile.linker_node_make_thunks += t1 - t0
12481250
self.profile.linker_make_thunk_time = linker_make_thunk_time
12491251

1250-
for node, thunk in zip(order, thunks):
1252+
for node, thunk in zip(order, thunks, strict=True):
12511253
thunk.inputs = [storage_map[v] for v in node.inputs]
12521254
thunk.outputs = [storage_map[v] for v in node.outputs]
12531255

@@ -1303,11 +1305,11 @@ def make_all(
13031305
vm,
13041306
[
13051307
Container(input, storage)
1306-
for input, storage in zip(fgraph.inputs, input_storage)
1308+
for input, storage in zip(fgraph.inputs, input_storage, strict=True)
13071309
],
13081310
[
13091311
Container(output, storage, readonly=True)
1310-
for output, storage in zip(fgraph.outputs, output_storage)
1312+
for output, storage in zip(fgraph.outputs, output_storage, strict=True)
13111313
],
13121314
thunks,
13131315
order,

0 commit comments

Comments
 (0)