Commit 88ef60d

Reduce pytensor function call overhead
1 parent 4e59f21 commit 88ef60d

12 files changed: +246 -364 lines changed


pytensor/compile/builders.py

Lines changed: 1 addition & 3 deletions
@@ -873,7 +873,5 @@ def clone(self):
 
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
-        assert len(variables) == len(outputs)
-        # strict=False because asserted above
-        for output, variable in zip(outputs, variables, strict=False):
+        for output, variable in zip(outputs, variables, strict=True):
             output[0] = variable
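
Side note on the builders.py change: with strict=True (Python 3.10+), zip itself raises ValueError when its arguments have different lengths, so the explicit assert above became redundant. A minimal sketch of that behaviour, not part of the commit:

    # Illustration only: strict=True makes zip enforce equal lengths.
    outputs = [[None], [None]]   # two output storage cells
    variables = (1.0,)           # the inner function returned only one value
    try:
        for output, variable in zip(outputs, variables, strict=True):
            output[0] = variable
    except ValueError as err:
        print(err)  # e.g. "zip() argument 2 is shorter than argument 1"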

pytensor/compile/function/types.py

Lines changed: 201 additions & 325 deletions
Large diffs are not rendered by default.

pytensor/link/basic.py

Lines changed: 6 additions & 10 deletions
@@ -87,10 +87,6 @@ def __set__(self, value: Any) -> None:
         if self.readonly:
             raise Exception(f"Cannot set readonly storage: {self.name}")
         try:
-            if value is None:
-                self.storage[0] = None
-                return
-
             kwargs = {}
             if self.strict:
                 kwargs["strict"] = True
@@ -539,14 +535,12 @@ def make_thunk(self, **kwargs):
 
         def f():
             for inputs in input_lists[1:]:
-                # strict=False because we are in a hot loop
-                for input1, input2 in zip(inputs0, inputs, strict=False):
+                for input1, input2 in zip(inputs0, inputs, strict=True):
                     input2.storage[0] = copy(input1.storage[0])
             for x in to_reset:
                 x[0] = None
             pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
-            # strict=False because we are in a hot loop
-            for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
+            for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=True)):
                 try:
                     wrapper(self.fgraph, i, node, *thunks)
                 except Exception:
@@ -668,10 +662,12 @@ def thunk(
                 # since the error may come from any of them?
                 raise_with_op(self.fgraph, output_nodes[0], thunk)
 
-            # strict=False because we are in a hot loop
-            for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
+            # strict=None because we are in a hot loop
+            for o_storage, o_val in zip(thunk_outputs, outputs):  # noqa: B905
                 o_storage[0] = o_val
 
+            return outputs
+
         thunk.inputs = thunk_inputs
         thunk.outputs = thunk_outputs
         thunk.lazy = False
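
A note on the noqa: B905 markers introduced above: lint rule B905 (flake8-bugbear, also enforced by ruff) flags zip() calls without an explicit strict= argument, and the suppression is used where the keyword is left out on purpose because these loops run on every compiled-function call. Passing any keyword argument to zip adds a small per-call cost. A rough micro-benchmark sketch of that overhead (illustrative only; absolute numbers depend on the Python version and machine):

    import timeit

    a = list(range(4))
    b = list(range(4))

    # Bare zip versus zip with a strict= keyword over a small pair of lists,
    # roughly mimicking the per-call loops in the linker thunks.
    t_bare = timeit.timeit(lambda: [y for _, y in zip(a, b)], number=1_000_000)
    t_kw = timeit.timeit(
        lambda: [y for _, y in zip(a, b, strict=False)], number=1_000_000
    )
    print(f"bare zip:        {t_bare:.3f}s")
    print(f"zip with strict: {t_kw:.3f}s")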

pytensor/link/c/basic.py

Lines changed: 5 additions & 6 deletions
@@ -1988,26 +1988,25 @@ def make_thunk(self, **kwargs):
         )
 
         def f():
-            # strict=False because we are in a hot loop
-            for input1, input2 in zip(i1, i2, strict=False):
+            for input1, input2 in zip(i1, i2, strict=True):
                 # Set the inputs to be the same in both branches.
                 # The copy is necessary in order for inplace ops not to
                 # interfere.
                 input2.storage[0] = copy(input1.storage[0])
             for thunk1, thunk2, node1, node2 in zip(
-                thunks1, thunks2, order1, order2, strict=False
+                thunks1, thunks2, order1, order2, strict=True
             ):
-                for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
+                for output, storage in zip(node1.outputs, thunk1.outputs, strict=True):
                     if output in no_recycling:
                         storage[0] = None
-                for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
+                for output, storage in zip(node2.outputs, thunk2.outputs, strict=True):
                     if output in no_recycling:
                         storage[0] = None
                 try:
                     thunk1()
                     thunk2()
                     for output1, output2 in zip(
-                        thunk1.outputs, thunk2.outputs, strict=False
+                        thunk1.outputs, thunk2.outputs, strict=True
                     ):
                         self.checker(output1, output2)
                 except Exception:

pytensor/link/numba/dispatch/basic.py

Lines changed: 2 additions & 2 deletions
@@ -312,10 +312,10 @@ def py_perform_return(inputs):
     else:
 
         def py_perform_return(inputs):
-            # strict=False because we are in a hot loop
+            # strict=None because we are in a hot loop
             return tuple(
                 out_type.filter(out[0])
-                for out_type, out in zip(output_types, py_perform(inputs), strict=False)
+                for out_type, out in zip(output_types, py_perform(inputs))  # noqa: B905
             )
 
     @numba_njit

pytensor/link/utils.py

Lines changed: 2 additions & 2 deletions
@@ -207,8 +207,8 @@ def streamline_nice_errors_f():
             for x in no_recycling:
                 x[0] = None
             try:
-                # strict=False because we are in a hot loop
-                for thunk, node in zip(thunks, order, strict=False):
+                # strict=None because we are in a hot loop
+                for thunk, node in zip(thunks, order):  # noqa: B905
                     thunk()
             except Exception:
                 raise_with_op(fgraph, node, thunk)

pytensor/scalar/basic.py

Lines changed: 2 additions & 2 deletions
@@ -4427,8 +4427,8 @@ def make_node(self, *inputs):
 
     def perform(self, node, inputs, output_storage):
         outputs = self.py_perform_fn(*inputs)
-        # strict=False because we are in a hot loop
-        for storage, out_val in zip(output_storage, outputs, strict=False):
+        # strict=None because we are in a hot loop
+        for storage, out_val in zip(output_storage, outputs):  # noqa: B905
             storage[0] = out_val
 
     def grad(self, inputs, output_grads):

pytensor/scalar/loop.py

Lines changed: 2 additions & 2 deletions
@@ -207,8 +207,8 @@ def perform(self, node, inputs, output_storage):
         for i in range(n_steps):
             carry = inner_fn(*carry, *constant)
 
-        # strict=False because we are in a hot loop
-        for storage, out_val in zip(output_storage, carry, strict=False):
+        # strict=None because we are in a hot loop
+        for storage, out_val in zip(output_storage, carry):  # noqa: B905
             storage[0] = out_val
 
     @property
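
The perform method above is a plain Python carry loop: the inner function is called n_steps times, each call consuming the previous carry plus the constant inputs, and only the final carry is written to the output storage. A pure-Python sketch of that data flow (the inner function and values are made up for illustration):

    # Hypothetical inner function: carry is (x, y), constants is (step,).
    def inner_fn(x, y, step):
        return x + step, y * 2

    n_steps = 3
    carry = (0.0, 1.0)
    constant = (0.5,)
    for _ in range(n_steps):
        carry = inner_fn(*carry, *constant)

    # Mimic writing the final carry into per-output storage cells.
    output_storage = [[None] for _ in carry]
    for storage, out_val in zip(output_storage, carry, strict=True):
        storage[0] = out_val
    print(output_storage)  # [[1.5], [8.0]]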

pytensor/tensor/random/basic.py

Lines changed: 2 additions & 2 deletions
@@ -1865,8 +1865,8 @@ def rng_fn(cls, rng, p, size):
             # to `p.shape[:-1]` in the call to `vsearchsorted` below.
             if len(size) < (p.ndim - 1):
                 raise ValueError("`size` is incompatible with the shape of `p`")
-            # strict=False because we are in a hot loop
-            for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False):
+            # strict=None because we are in a hot loop
+            for s, ps in zip(reversed(size), reversed(p.shape[:-1])):  # noqa: B905
                 if s == 1 and ps != 1:
                     raise ValueError("`size` is incompatible with the shape of `p`")
 
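The loop above checks, aligning from the right, that the requested size does not pin a dimension to 1 where p actually carries several batches of probabilities. A standalone illustration of the same logic (not the pytensor implementation):

    # p has batch shape p.shape[:-1]; here: 3 batches of 5 category probabilities.
    p_shape = (3, 5)
    for size in [(3,), (2, 3), (1,)]:
        ok = True
        for s, ps in zip(reversed(size), reversed(p_shape[:-1])):
            if s == 1 and ps != 1:
                ok = False
        print(size, "compatible" if ok else "incompatible")
    # (3,) and (2, 3) pass; (1,) is rejected because it would have to be
    # broadcast against the 3 batches in p.
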
pytensor/tensor/random/utils.py

Lines changed: 5 additions & 5 deletions
@@ -44,8 +44,8 @@ def params_broadcast_shapes(
     max_fn = maximum if use_pytensor else max
 
     rev_extra_dims: list[int] = []
-    # strict=False because we are in a hot loop
-    for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False):
+    # strict=None because we are in a hot loop
+    for ndim_param, param_shape in zip(ndims_params, param_shapes):  # noqa: B905
         # We need this in order to use `len`
         param_shape = tuple(param_shape)
         extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
@@ -69,7 +69,7 @@ def max_bcast(x, y):
         (extra_dims + tuple(param_shape)[-ndim_param:])
         if ndim_param > 0
         else extra_dims
-        for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False)
+        for ndim_param, param_shape in zip(ndims_params, param_shapes)  # noqa: B905
     ]
 
     return bcast_shapes
@@ -127,10 +127,10 @@ def broadcast_params(
     )
     broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to
 
-    # strict=False because we are in a hot loop
+    # strict=None because we are in a hot loop
     bcast_params = [
         broadcast_to_fn(param, shape)
-        for shape, param in zip(shapes, params, strict=False)
+        for shape, param in zip(shapes, params)  # noqa: B905
     ]
 
     return bcast_params
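
For reference, broadcast_params broadcasts the batch dimensions of random-variable parameters against each other while leaving each parameter's core dimensions (given by ndims_params) untouched. A small usage sketch with NumPy inputs; the shapes are made up for illustration and the printed result is what I would expect under these assumptions:

    import numpy as np
    from pytensor.tensor.random.utils import broadcast_params

    # mu has batch shape (3, 1) and sigma has batch shape (4,); with
    # ndims_params=[0, 0] both parameters are scalar per draw, so only the
    # batch dimensions are broadcast, giving (3, 4) for both.
    mu = np.zeros((3, 1))
    sigma = np.ones((4,))
    bcast_mu, bcast_sigma = broadcast_params([mu, sigma], ndims_params=[0, 0])
    print(bcast_mu.shape, bcast_sigma.shape)  # expected: (3, 4) (3, 4)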
