Skip to content

Commit 3ed2c49

Browse files
authored
Enable no-cpython-wrapper in numba where possible (#765)
* Enable no-cpython-wrapper in numba where possible * Fix test with no_cpython_wrapper * Add docstring to numba_funcify
1 parent 15b90be commit 3ed2c49

File tree

4 files changed

+28
-5
lines changed

4 files changed

+28
-5
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def global_numba_func(func):
5959

6060
def numba_njit(*args, **kwargs):
6161
kwargs.setdefault("cache", config.numba__cache)
62+
kwargs.setdefault("no_cpython_wrapper", True)
63+
kwargs.setdefault("no_cfunc_wrapper", True)
6264

6365
# Supress caching warnings
6466
warnings.filterwarnings(
@@ -419,7 +421,12 @@ def perform(*inputs):
419421

420422
@singledispatch
421423
def numba_funcify(op, node=None, storage_map=None, **kwargs):
422-
"""Generate a numba function for a given op and apply node."""
424+
"""Generate a numba function for a given op and apply node.
425+
426+
The resulting function will usually use the `no_cpython_wrapper`
427+
argument in numba, so it can not be called directly from python,
428+
but only from other jit functions.
429+
"""
423430
return generate_fallback_impl(op, node, storage_map, **kwargs)
424431

425432

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,9 @@ def axis_apply_fn(x):
470470
"afn", # Approximate functions
471471
"reassoc",
472472
"nsz", # TODO Do we want this one?
473-
}
473+
},
474+
"no_cpython_wrapper": True,
475+
"no_cfunc_wrapper": True,
474476
}
475477

476478

@@ -698,7 +700,14 @@ def elemwise(*inputs):
698700
return tuple(outputs_summed)
699701
return outputs_summed[0]
700702

701-
@overload(elemwise)
703+
@overload(
704+
elemwise,
705+
jit_options={
706+
"fastmath": flags,
707+
"no_cpython_wrapper": True,
708+
"no_cfunc_wrapper": True,
709+
},
710+
)
702711
def ov_elemwise(*inputs):
703712
return elemwise_wrapper
704713

pytensor/link/numba/linker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def fgraph_convert(self, fgraph, **kwargs):
2929
def jit_compile(self, fn):
3030
from pytensor.link.numba.dispatch.basic import numba_njit
3131

32-
jitted_fn = numba_njit(fn)
32+
jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
3333
return jitted_fn
3434

3535
def create_thunk_inputs(self, storage_map):

tests/link/numba/test_tensor_basic.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,8 @@ def test_ExtractDiag(val, offset):
386386
)
387387
@pytest.mark.parametrize("reverse_axis", (False, True))
388388
def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
389+
from pytensor.link.numba.dispatch.basic import numba_njit
390+
389391
if reverse_axis:
390392
axis1, axis2 = axis2, axis1
391393

@@ -394,7 +396,12 @@ def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
394396
x_test = np.arange(np.prod(x_shape)).reshape(x_shape)
395397
out = pt.diagonal(x, k, axis1, axis2)
396398
numba_fn = numba_funcify(out.owner.op, out.owner)
397-
np.testing.assert_allclose(numba_fn(x_test), np.diagonal(x_test, k, axis1, axis2))
399+
400+
@numba_njit(no_cpython_wrapper=False)
401+
def wrap(x):
402+
return numba_fn(x)
403+
404+
np.testing.assert_allclose(wrap(x_test), np.diagonal(x_test, k, axis1, axis2))
398405

399406

400407
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)