Skip to content

Commit 643c973

Browse files
Add a Numba OpFromGraph implementation
1 parent 308969c commit 643c973

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

aesara/link/numba/dispatch/basic.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from numba.extending import box
1818

1919
from aesara import config
20+
from aesara.compile.builders import OpFromGraph
2021
from aesara.compile.ops import DeepCopyOp
2122
from aesara.graph.basic import Apply, NoParams
2223
from aesara.graph.fg import FunctionGraph
@@ -374,6 +375,25 @@ def perform(*inputs):
374375
return perform
375376

376377

378+
@numba_funcify.register(OpFromGraph)
379+
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
380+
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
381+
382+
if len(op.fgraph.outputs) == 1:
383+
384+
@numba_njit
385+
def opfromgraph(*inputs):
386+
return fgraph_fn(*inputs)[0]
387+
388+
else:
389+
390+
@numba_njit
391+
def opfromgraph(*inputs):
392+
return fgraph_fn(*inputs)
393+
394+
return opfromgraph
395+
396+
377397
@numba_funcify.register(FunctionGraph)
378398
def numba_funcify_FunctionGraph(
379399
fgraph,

tests/link/numba/test_basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import aesara.tensor as at
1313
import aesara.tensor.math as aem
1414
from aesara import config, shared
15+
from aesara.compile.builders import OpFromGraph
1516
from aesara.compile.function import function
1617
from aesara.compile.mode import Mode
1718
from aesara.compile.ops import ViewOp
@@ -1003,3 +1004,18 @@ def test_scalar_return_value_conversion():
10031004
mode=numba_mode,
10041005
)
10051006
assert isinstance(x_fn(1.0), np.ndarray)
1007+
1008+
1009+
def test_OpFromGraph():
1010+
x, y, z = at.matrices("xyz")
1011+
ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
1012+
ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)
1013+
1014+
o1, o2 = ofg_2(y, z)
1015+
out = ofg_1(x, o1) + o2
1016+
1017+
xv = np.ones((2, 2), dtype=config.floatX)
1018+
yv = np.ones((2, 2), dtype=config.floatX) * 3
1019+
zv = np.ones((2, 2), dtype=config.floatX) * 5
1020+
1021+
compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv])

0 commit comments

Comments
 (0)