
Commit 984adeb

Add exceptions for hot loops
1 parent dac36f9 commit 984adeb

11 files changed (+35, -25 lines)
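Every loop touched here runs inside the runtime of a compiled function, where even a small per-call cost compounds. Below is a minimal sketch of the kind of overhead being avoided, using only the standard library; the sequence lengths and iteration count are made up, and exact timings depend on the CPython version:

import timeit

# Short sequences zipped once per call, like the output-storage and
# thunk lists in the loops below.
a = list(range(4))
b = list(range(4))

default = timeit.timeit(lambda: list(zip(a, b)), number=1_000_000)
strict = timeit.timeit(lambda: list(zip(a, b, strict=True)), number=1_000_000)
print(f"zip(a, b):              {default:.2f}s")
# Typically somewhat slower: keyword-argument parsing on construction
# plus the length check at exhaustion.
print(f"zip(a, b, strict=True): {strict:.2f}s")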

pytensor/compile/builders.py

Lines changed: 2 additions & 1 deletion
@@ -863,5 +863,6 @@ def clone(self):
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
-        for output, variable in zip(outputs, variables, strict=True):
+        # strict=False because asserted above
+        for output, variable in zip(outputs, variables, strict=False):
             output[0] = variable
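Here, as in pytensor/link/pytorch/dispatch/shape.py below, an explicit length check already guards the loop, so dropping strict loses no safety. A toy illustration of the pattern, with hypothetical values:

outputs = [[None], [None]]
variables = (1, 2)

assert len(variables) == len(outputs)  # a length mismatch already fails here
for output, variable in zip(outputs, variables):  # so a strict zip would only re-check it
    output[0] = variable
print(outputs)  # [[1], [2]]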

pytensor/compile/function/types.py

Lines changed: 2 additions & 1 deletion
@@ -1002,8 +1002,9 @@ def __call__(self, *args, **kwargs):
         # if we are allowing garbage collection, remove the
         # output reference from the internal storage cells
         if getattr(self.vm, "allow_gc", False):
+            # strict=False because we are in a hot loop
             for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs, strict=True
+                self.output_storage, self.maker.fgraph.outputs, strict=False
             ):
                 if o_variable.owner is not None:
                     # this node is the variable of computation

pytensor/ifelse.py

Lines changed: 4 additions & 2 deletions
@@ -305,7 +305,8 @@ def thunk():
         if len(ls) > 0:
             return ls
         else:
-            for out, t in zip(outputs, input_true_branch, strict=True):
+            # strict=False because we are in a hot loop
+            for out, t in zip(outputs, input_true_branch, strict=False):
                 compute_map[out][0] = 1
                 val = storage_map[t][0]
                 if self.as_view:
@@ -325,7 +326,8 @@ def thunk():
         if len(ls) > 0:
             return ls
         else:
-            for out, f in zip(outputs, inputs_false_branch, strict=True):
+            # strict=False because we are in a hot loop
+            for out, f in zip(outputs, inputs_false_branch, strict=False):
                 compute_map[out][0] = 1
                 # can't view both outputs unless destroyhandler
                 # improves

pytensor/link/basic.py

Lines changed: 4 additions & 2 deletions
@@ -544,7 +544,8 @@ def f():
         for x in to_reset:
             x[0] = None
         pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
-        for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=True)):
+        # strict=False because we are in a hot loop
+        for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
             try:
                 wrapper(self.fgraph, i, node, *thunks)
             except Exception:
@@ -666,8 +667,9 @@ def thunk(
         ):
             outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])

+            # strict=False because we are in a hot loop
             for o_var, o_storage, o_val in zip(
-                fgraph.outputs, thunk_outputs, outputs, strict=True
+                fgraph.outputs, thunk_outputs, outputs, strict=False
             ):
                 compute_map[o_var][0] = True
                 o_storage[0] = self.output_filter(o_var, o_val)

pytensor/link/c/basic.py

Lines changed: 6 additions & 5 deletions
@@ -1993,25 +1993,26 @@ def make_thunk(self, **kwargs):
         )

         def f():
-            for input1, input2 in zip(i1, i2, strict=True):
+            # strict=False because we are in a hot loop
+            for input1, input2 in zip(i1, i2, strict=False):
                 # Set the inputs to be the same in both branches.
                 # The copy is necessary in order for inplace ops not to
                 # interfere.
                 input2.storage[0] = copy(input1.storage[0])
             for thunk1, thunk2, node1, node2 in zip(
-                thunks1, thunks2, order1, order2, strict=True
+                thunks1, thunks2, order1, order2, strict=False
             ):
-                for output, storage in zip(node1.outputs, thunk1.outputs, strict=True):
+                for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
                     if output in no_recycling:
                         storage[0] = None
-                for output, storage in zip(node2.outputs, thunk2.outputs, strict=True):
+                for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
                     if output in no_recycling:
                         storage[0] = None
                 try:
                     thunk1()
                     thunk2()
                     for output1, output2 in zip(
-                        thunk1.outputs, thunk2.outputs, strict=True
+                        thunk1.outputs, thunk2.outputs, strict=False
                     ):
                         self.checker(output1, output2)
                 except Exception:

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 3 additions & 1 deletion
@@ -158,7 +158,9 @@ def advancedincsubtensor1_inplace(x, val, idxs):
        def advancedincsubtensor1_inplace(x, vals, idxs):
            if not len(idxs) == len(vals):
                raise ValueError("The number of indices and values must match.")
-           for idx, val in zip(idxs, vals, strict=True):
+           # strict=False to not add overhead in the jitted code
+           # TODO: check
+           for idx, val in zip(idxs, vals, strict=False):
                x[idx] = val
            return x
    else:

pytensor/link/pytorch/dispatch/shape.py

Lines changed: 2 additions & 1 deletion
@@ -34,7 +34,8 @@ def shape_i(x):
 def pytorch_funcify_SpecifyShape(op, node, **kwargs):
     def specifyshape(x, *shape):
         assert x.ndim == len(shape)
-        for actual, expected in zip(x.shape, shape, strict=True):
+        # strict=False because asserted above
+        for actual, expected in zip(x.shape, shape, strict=False):
             if expected is None:
                 continue
             if actual != expected:

pytensor/scalar/basic.py

Lines changed: 4 additions & 2 deletions
@@ -1150,8 +1150,9 @@ def perform(self, node, inputs, output_storage):
         else:
             variables = from_return_values(self.impl(*inputs))
             assert len(variables) == len(output_storage)
+            # strict=False because we are in a hot loop
             for out, storage, variable in zip(
-                node.outputs, output_storage, variables, strict=True
+                node.outputs, output_storage, variables, strict=False
             ):
                 dtype = out.dtype
                 storage[0] = self._cast_scalar(variable, dtype)
@@ -4328,7 +4329,8 @@ def make_node(self, *inputs):

     def perform(self, node, inputs, output_storage):
         outputs = self.py_perform_fn(*inputs)
-        for storage, out_val in zip(output_storage, outputs, strict=True):
+        # strict=False because we are in a hot loop
+        for storage, out_val in zip(output_storage, outputs, strict=False):
             storage[0] = out_val

     def grad(self, inputs, output_grads):

pytensor/scalar/loop.py

Lines changed: 3 additions & 2 deletions
@@ -93,7 +93,7 @@ def _validate_updates(
             )
         else:
             update = outputs
-        for i, u in zip(init[: len(update)], update, strict=True):
+        for i, u in zip(init, update, strict=False):
             if i.type != u.type:
                 raise TypeError(
                     "Init and update types must be the same: "
@@ -207,7 +207,8 @@ def perform(self, node, inputs, output_storage):
         for i in range(n_steps):
             carry = inner_fn(*carry, *constant)

-        for storage, out_val in zip(output_storage, carry, strict=True):
+        # strict=False because we are in a hot loop
+        for storage, out_val in zip(output_storage, carry, strict=False):
             storage[0] = out_val

     @property
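The _validate_updates hunk also drops the init[: len(update)] slice: since zip without strict stops at the shorter input, the slice was only needed to satisfy strict=True. The dims[: len(slices)] slice removed in pytensor/tensor/rewriting/subtensor.py below falls to the same argument. A quick check of the equivalence, with hypothetical values:

init = ["a", "b", "c"]
update = ["x", "y"]

# zip truncates to the shorter input, so pre-slicing the longer one
# is redundant once strict checking is off:
assert list(zip(init, update)) == list(zip(init[: len(update)], update))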

pytensor/tensor/rewriting/subtensor.py

Lines changed: 1 addition & 1 deletion
@@ -683,7 +683,7 @@ def local_subtensor_of_alloc(fgraph, node):
     # Slices to take from val
     val_slices = []

-    for i, (sl, dim) in enumerate(zip(slices, dims[: len(slices)], strict=True)):
+    for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
         # If val was not copied over that dim,
         # we need to take the appropriate subtensor on it.
         if i >= n_added_dims:
