
Commit f8bc010

Add exceptions for hot loops
1 parent 8b75563 commit f8bc010
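Context for the change below (editorial note, not part of the commit): since Python 3.10, zip(..., strict=True) raises ValueError when its arguments have unequal lengths, and that safety check has a small per-call cost. This commit drops the check in per-call hot paths where the lengths are already guaranteed. A minimal micro-benchmark sketch of the trade-off, with illustrative names:

import timeit

a = list(range(10))
b = list(range(10))

# With equal-length inputs the two calls produce identical results;
# strict only changes behavior on a mismatch.
t_strict = timeit.timeit(lambda: list(zip(a, b, strict=True)), number=1_000_000)
t_loose = timeit.timeit(lambda: list(zip(a, b, strict=False)), number=1_000_000)
print(f"strict=True:  {t_strict:.3f}s")
print(f"strict=False: {t_loose:.3f}s")

Any measured gap is small in absolute terms; the commit's point is that it is paid on every call in these paths.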

20 files changed, +80 -49 lines changed


pytensor/compile/builders.py

Lines changed: 2 additions & 1 deletion
@@ -863,5 +863,6 @@ def clone(self):
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
-        for output, variable in zip(outputs, variables, strict=True):
+        # strict=False because asserted above
+        for output, variable in zip(outputs, variables, strict=False):
             output[0] = variable
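The pattern here, an explicit length assert followed by strict=False, relies on strict only mattering when the lengths differ. A short sketch of the standard CPython semantics (not from the commit):

a, b = [1, 2], [10]
list(zip(a, b, strict=False))  # [(1, 10)] -- the longer input is silently truncated
list(zip(a, b, strict=True))   # ValueError: zip() argument 2 is shorter than argument 1

With the assert in place, both forms behave identically, so strict=True would only re-check what the assert already established.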

pytensor/compile/function/types.py

Lines changed: 8 additions & 4 deletions
@@ -1002,8 +1002,9 @@ def __call__(self, *args, **kwargs):
         # if we are allowing garbage collection, remove the
         # output reference from the internal storage cells
         if getattr(self.vm, "allow_gc", False):
+            # strict=False because we are in a hot loop
             for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs, strict=True
+                self.output_storage, self.maker.fgraph.outputs, strict=False
             ):
                 if o_variable.owner is not None:
                     # this node is the variable of computation
@@ -1012,8 +1013,9 @@ def __call__(self, *args, **kwargs):

         if getattr(self.vm, "need_update_inputs", True):
             # Update the inputs that have an update function
+            # strict=False because we are in a hot loop
             for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, input_storage, strict=True))
+                list(zip(self.maker.expanded_inputs, input_storage, strict=False))
             ):
                 if input.update is not None:
                     storage.data = outputs.pop()
@@ -1044,7 +1046,8 @@ def __call__(self, *args, **kwargs):
         assert len(self.output_keys) == len(outputs)

         if output_subset is None:
-            return dict(zip(self.output_keys, outputs, strict=True))
+            # strict=False because we are in a hot loop
+            return dict(zip(self.output_keys, outputs, strict=False))
         else:
             return {
                 self.output_keys[index]: outputs[index]
@@ -1111,8 +1114,9 @@ def _pickle_Function(f):
     ins = list(f.input_storage)
     input_storage = []

+    # strict=False because we are in a hot loop
     for (input, indices, inputs), (required, refeed, default) in zip(
-        f.indices, f.defaults, strict=True
+        f.indices, f.defaults, strict=False
     ):
         input_storage.append(ins[0])
         del ins[0]

pytensor/ifelse.py

Lines changed: 4 additions & 2 deletions
@@ -305,7 +305,8 @@ def thunk():
             if len(ls) > 0:
                 return ls
             else:
-                for out, t in zip(outputs, input_true_branch, strict=True):
+                # strict=False because we are in a hot loop
+                for out, t in zip(outputs, input_true_branch, strict=False):
                     compute_map[out][0] = 1
                     val = storage_map[t][0]
                     if self.as_view:
@@ -325,7 +326,8 @@ def thunk():
             if len(ls) > 0:
                 return ls
             else:
-                for out, f in zip(outputs, inputs_false_branch, strict=True):
+                # strict=False because we are in a hot loop
+                for out, f in zip(outputs, inputs_false_branch, strict=False):
                     compute_map[out][0] = 1
                     # can't view both outputs unless destroyhandler
                     # improves

pytensor/link/basic.py

Lines changed: 6 additions & 3 deletions
@@ -539,12 +539,14 @@ def make_thunk(self, **kwargs):

         def f():
             for inputs in input_lists[1:]:
-                for input1, input2 in zip(inputs0, inputs, strict=True):
+                # strict=False because we are in a hot loop
+                for input1, input2 in zip(inputs0, inputs, strict=False):
                     input2.storage[0] = copy(input1.storage[0])
             for x in to_reset:
                 x[0] = None
             pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
-            for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=True)):
+            # strict=False because we are in a hot loop
+            for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
                 try:
                     wrapper(self.fgraph, i, node, *thunks)
                 except Exception:
@@ -666,8 +668,9 @@ def thunk(
         ):
             outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])

+            # strict=False because we are in a hot loop
             for o_var, o_storage, o_val in zip(
-                fgraph.outputs, thunk_outputs, outputs, strict=True
+                fgraph.outputs, thunk_outputs, outputs, strict=False
             ):
                 compute_map[o_var][0] = True
                 o_storage[0] = self.output_filter(o_var, o_val)

pytensor/link/c/basic.py

Lines changed: 6 additions & 5 deletions
@@ -1993,25 +1993,26 @@ def make_thunk(self, **kwargs):
         )

         def f():
-            for input1, input2 in zip(i1, i2, strict=True):
+            # strict=False because we are in a hot loop
+            for input1, input2 in zip(i1, i2, strict=False):
                 # Set the inputs to be the same in both branches.
                 # The copy is necessary in order for inplace ops not to
                 # interfere.
                 input2.storage[0] = copy(input1.storage[0])
             for thunk1, thunk2, node1, node2 in zip(
-                thunks1, thunks2, order1, order2, strict=True
+                thunks1, thunks2, order1, order2, strict=False
             ):
-                for output, storage in zip(node1.outputs, thunk1.outputs, strict=True):
+                for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
                     if output in no_recycling:
                         storage[0] = None
-                for output, storage in zip(node2.outputs, thunk2.outputs, strict=True):
+                for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
                     if output in no_recycling:
                         storage[0] = None
                 try:
                     thunk1()
                     thunk2()
                     for output1, output2 in zip(
-                        thunk1.outputs, thunk2.outputs, strict=True
+                        thunk1.outputs, thunk2.outputs, strict=False
                     ):
                         self.checker(output1, output2)
                 except Exception:

pytensor/link/numba/dispatch/basic.py

Lines changed: 2 additions & 1 deletion
@@ -401,9 +401,10 @@ def py_perform_return(inputs):
     else:

         def py_perform_return(inputs):
+            # strict=False because we are in a hot loop
             return tuple(
                 out_type.filter(out[0])
-                for out_type, out in zip(output_types, py_perform(inputs), strict=True)
+                for out_type, out in zip(output_types, py_perform(inputs), strict=False)
             )

     @numba_njit

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 3 additions & 1 deletion
@@ -158,7 +158,9 @@ def advancedincsubtensor1_inplace(x, val, idxs):
         def advancedincsubtensor1_inplace(x, vals, idxs):
             if not len(idxs) == len(vals):
                 raise ValueError("The number of indices and values must match.")
-            for idx, val in zip(idxs, vals, strict=True):
+            # strict=False to not add overhead in the jitted code
+            # TODO: check
+            for idx, val in zip(idxs, vals, strict=False):
                 x[idx] = val
             return x
     else:
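The "TODO: check" suggests the overhead claim for jitted code was not yet verified. As context, a hypothetical sketch of the compiled pattern (names and setup are illustrative, not from the commit), where the explicit up-front length check plays the role a strict zip would otherwise play:

import numpy as np
from numba import njit

@njit
def set_at_indices(x, idxs, vals):
    # Up-front validation instead of a per-zip strict check.
    if len(idxs) != len(vals):
        raise ValueError("The number of indices and values must match.")
    for idx, val in zip(idxs, vals):
        x[idx] = val
    return x

set_at_indices(np.zeros(5), np.array([1, 3]), np.array([7.0, 9.0]))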

pytensor/link/pytorch/dispatch/shape.py

Lines changed: 2 additions & 1 deletion
@@ -34,7 +34,8 @@ def shape_i(x):
 def pytorch_funcify_SpecifyShape(op, node, **kwargs):
     def specifyshape(x, *shape):
         assert x.ndim == len(shape)
-        for actual, expected in zip(x.shape, shape, strict=True):
+        # strict=False because asserted above
+        for actual, expected in zip(x.shape, shape, strict=False):
             if expected is None:
                 continue
             if actual != expected:

pytensor/link/utils.py

Lines changed: 4 additions & 2 deletions
@@ -190,8 +190,9 @@ def streamline_default_f():
         for x in no_recycling:
             x[0] = None
         try:
+            # strict=False because we are in a hot loop
             for thunk, node, old_storage in zip(
-                thunks, order, post_thunk_old_storage, strict=True
+                thunks, order, post_thunk_old_storage, strict=False
             ):
                 thunk()
                 for old_s in old_storage:
@@ -206,7 +207,8 @@ def streamline_nice_errors_f():
         for x in no_recycling:
             x[0] = None
         try:
-            for thunk, node in zip(thunks, order, strict=True):
+            # strict=False because we are in a hot loop
+            for thunk, node in zip(thunks, order, strict=False):
                 thunk()
         except Exception:
             raise_with_op(fgraph, node, thunk)

pytensor/scalar/basic.py

Lines changed: 4 additions & 2 deletions
@@ -1150,8 +1150,9 @@ def perform(self, node, inputs, output_storage):
         else:
             variables = from_return_values(self.impl(*inputs))
             assert len(variables) == len(output_storage)
+            # strict=False because we are in a hot loop
             for out, storage, variable in zip(
-                node.outputs, output_storage, variables, strict=True
+                node.outputs, output_storage, variables, strict=False
             ):
                 dtype = out.dtype
                 storage[0] = self._cast_scalar(variable, dtype)
@@ -4328,7 +4329,8 @@ def make_node(self, *inputs):

     def perform(self, node, inputs, output_storage):
         outputs = self.py_perform_fn(*inputs)
-        for storage, out_val in zip(output_storage, outputs, strict=True):
+        # strict=False because we are in a hot loop
+        for storage, out_val in zip(output_storage, outputs, strict=False):
             storage[0] = out_val

     def grad(self, inputs, output_grads):
