Skip to content

Commit 74156ec

Browse files
committed
Speedup python blockwise
1 parent 60bfb3b commit 74156ec

File tree

1 file changed

+75
-38
lines changed

1 file changed

+75
-38
lines changed

pytensor/tensor/blockwise.py

Lines changed: 75 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from collections.abc import Sequence
1+
from collections.abc import Callable, Sequence
22
from typing import Any, cast
33

44
import numpy as np
5+
from numpy import broadcast_shapes, empty, ndindex, nditer
56

67
from pytensor import config
78
from pytensor.compile.builders import OpFromGraph
@@ -28,6 +29,67 @@
2829
from pytensor.tensor.variable import TensorVariable
2930

3031

32+
def _vectorize_node_perform(core_node, batch_ndim: int):
33+
core_op_perform = core_node.op.perform
34+
n_outs = len(core_node.outputs)
35+
36+
def vectorized_func(
37+
*args,
38+
core_node=core_node,
39+
core_op_perform=core_op_perform,
40+
batch_ndim=batch_ndim,
41+
n_outs=n_outs,
42+
):
43+
batch_shape = broadcast_shapes(*(arg.shape[:batch_ndim] for arg in args))
44+
args = list(args)
45+
for i, arg in enumerate(args):
46+
if arg.shape[:batch_ndim] != batch_shape:
47+
# Main logic of `np.broadcast_to`
48+
it = nditer(
49+
(arg,),
50+
flags=["multi_index", "zerosize_ok"],
51+
op_flags=["readonly"],
52+
itershape=batch_shape + arg.shape[batch_ndim:],
53+
order="C",
54+
)
55+
with it:
56+
args[i] = it.itviews[0]
57+
58+
core_output_storage = [[None] for _ in range(n_outs)]
59+
ndindex_iterator = ndindex(*batch_shape)
60+
# Call once to get the output shapes
61+
try:
62+
# TODO: Pass core shape as input like BlockwiseWithCoreShape does?
63+
index0 = next(ndindex_iterator)
64+
except StopIteration:
65+
raise NotImplementedError("vectorize with zero iterations not implemented")
66+
else:
67+
core_op_perform(
68+
core_node,
69+
[np.asarray(arg[index0]) for arg in args],
70+
core_output_storage,
71+
)
72+
outputs = tuple(
73+
empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype)
74+
for core_output in core_output_storage
75+
)
76+
for output, core_output in zip(outputs, core_output_storage): # noqa: B905
77+
output[index0] = core_output[0]
78+
79+
for index in ndindex_iterator:
80+
core_op_perform(
81+
core_node,
82+
[np.asarray(a[index]) for a in args],
83+
core_output_storage,
84+
)
85+
for output, core_output in zip(outputs, core_output_storage): # noqa: B905
86+
output[index] = core_output[0]
87+
88+
return outputs
89+
90+
return vectorized_func
91+
92+
3193
class Blockwise(Op):
3294
"""Generalizes a core `Op` to work with batched dimensions.
3395
@@ -308,46 +370,29 @@ def L_op(self, inputs, outs, ograds):
308370

309371
return rval
310372

311-
def _create_node_gufunc(self, node) -> None:
373+
def _create_node_gufunc(self, node: Apply) -> Callable:
    """Create the node gufunc used in `perform`.

    If the Blockwise or core_op has a `gufunc_spec`, the relevant numpy or
    scipy gufunc is imported and used directly. Otherwise, we build a
    vectorized loop around the core_op `perform` method for a dummy node
    (see `_vectorize_node_perform`).

    Returns the gufunc; the caller (`perform`) caches it in `node.tag`.

    Raises
    ------
    ValueError
        If a `gufunc_spec` is given but its function cannot be imported.
    """
    if (
        gufunc_spec := self.gufunc_spec
        or getattr(self.core_op, "gufunc_spec", None)
    ) is not None:
        gufunc = import_func_from_string(gufunc_spec[0])
        if gufunc is None:
            raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
    else:
        core_node = self._create_dummy_core_node(node.inputs)
        gufunc = _vectorize_node_perform(
            core_node,
            batch_ndim=self.batch_ndim(node),
        )

    return gufunc
351396

352397
def _check_runtime_broadcast(self, node, inputs):
353398
batch_ndim = self.batch_ndim(node)
@@ -375,23 +420,15 @@ def perform(self, node, inputs, output_storage):
375420
gufunc = getattr(node.tag, "gufunc", None)
376421

377422
if gufunc is None:
378-
# Cache it once per node
379-
self._create_node_gufunc(node)
380-
gufunc = node.tag.gufunc
423+
gufunc = node.tag.gufunc = self._create_node_gufunc(node)
381424

382425
self._check_runtime_broadcast(node, inputs)
383426

384427
res = gufunc(*inputs)
385428
if not isinstance(res, tuple):
386429
res = (res,)
387430

388-
# strict=False because we are in a hot loop
389-
for node_out, out_storage, r in zip(
390-
node.outputs, output_storage, res, strict=False
391-
):
392-
out_dtype = getattr(node_out, "dtype", None)
393-
if out_dtype and out_dtype != r.dtype:
394-
r = np.asarray(r, dtype=out_dtype)
431+
for node_out, out_storage, r in zip(node.outputs, output_storage, res): # noqa: B905
395432
out_storage[0] = r
396433

397434
def __str__(self):

0 commit comments

Comments
 (0)