Skip to content

Commit 5c8f09d

Browse files
committed
Add dpnp ufunc specific registry
1 parent 8d410e7 commit 5c8f09d

File tree

4 files changed

+44
-9
lines changed

4 files changed

+44
-9
lines changed

numba_dpex/core/targets/dpjit_target.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ class Dpex(CPU):
3232
class DpexTypingContext(typing.Context):
3333
"""Custom typing context to support dpjit compilation."""
3434

35+
def load_additional_registries(self):
36+
"""Register dpjit specific functions like dpnp ufuncs."""
37+
from numba_dpex.core.typing import dpnpdecl
38+
39+
self.install_registry(dpnpdecl.registry)
40+
super().load_additional_registries()
41+
3542

3643
class DpexTargetContext(CPUContext):
3744
def __init__(self, typingctx, target=DPEX_TARGET_NAME):

numba_dpex/core/targets/kernel_target.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,16 @@ def load_additional_registries(self):
111111
"""Register the OpenCL API and math and other functions."""
112112
from numba.core.typing import cmathdecl, enumdecl, npydecl
113113

114+
from numba_dpex.core.typing import dpnpdecl
115+
114116
from ...ocl import mathdecl, ocldecl
115117

116118
self.install_registry(ocldecl.registry)
117119
self.install_registry(mathdecl.registry)
118120
self.install_registry(cmathdecl.registry)
121+
# TODO: https://github.com/IntelPython/numba-dpex/issues/1270
119122
self.install_registry(npydecl.registry)
123+
self.install_registry(dpnpdecl.registry)
120124
self.install_registry(enumdecl.registry)
121125

122126

numba_dpex/core/typing/dpnpdecl.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,24 @@
1010
NumpyRulesArrayOperator,
1111
NumpyRulesInplaceArrayOperator,
1212
NumpyRulesUnaryArrayOperator,
13-
infer_global,
1413
)
14+
from numba.core.typing.templates import Registry
15+
16+
registry = Registry()
17+
infer = registry.register
18+
infer_global = registry.register_global
19+
infer_getattr = registry.register_attr
20+
21+
22+
def _install_operations(cls: NumpyRulesArrayOperator):
23+
for op, ufunc_name in cls._op_map.items():
24+
infer_global(op)(
25+
type(
26+
"NumpyRulesArrayOperator_" + ufunc_name,
27+
(cls,),
28+
dict(key=op),
29+
)
30+
)
1531

1632

1733
class DpnpRulesArrayOperator(NumpyRulesArrayOperator):
@@ -29,13 +45,21 @@ def ufunc(self):
2945
except:
3046
pass
3147

48+
@classmethod
49+
def install_operations(cls):
50+
_install_operations(cls)
51+
3252

3353
class DpnpRulesInplaceArrayOperator(NumpyRulesInplaceArrayOperator):
34-
pass
54+
@classmethod
55+
def install_operations(cls):
56+
_install_operations(cls)
3557

3658

3759
class DpnpRulesUnaryArrayOperator(NumpyRulesUnaryArrayOperator):
38-
pass
60+
@classmethod
61+
def install_operations(cls):
62+
_install_operations(cls)
3963

4064

4165
# list of unary ufuncs to register

numba_dpex/dpnp_iface/dpnpimpl.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from numba.core.imputils import Registry
99
from numba.np import npyimpl
1010

11-
from numba_dpex.core.typing.dpnpdecl import _unsupported
11+
from numba_dpex.core.typing import dpnpdecl
1212
from numba_dpex.dpnp_iface import dpnp_ufunc_db
1313

1414
registry = Registry("dpnpimpl")
@@ -36,11 +36,11 @@ def _register_dpnp_ufuncs():
3636
)
3737

3838
for _op_map in (
39-
npyimpl.npydecl.NumpyRulesUnaryArrayOperator._op_map,
40-
npyimpl.npydecl.NumpyRulesArrayOperator._op_map,
39+
dpnpdecl.DpnpRulesUnaryArrayOperator._op_map,
40+
dpnpdecl.DpnpRulesArrayOperator._op_map,
4141
):
4242
for operator, ufunc_name in _op_map.items():
43-
if ufunc_name in _unsupported:
43+
if ufunc_name in dpnpdecl._unsupported:
4444
continue
4545
ufunc = getattr(dpnp, ufunc_name)
4646
kernel = kernels[ufunc]
@@ -57,9 +57,9 @@ def _register_dpnp_ufuncs():
5757
"There shouldn't be any non-unary or binary operators"
5858
)
5959

60-
for _op_map in (npyimpl.npydecl.NumpyRulesInplaceArrayOperator._op_map,):
60+
for _op_map in (dpnpdecl.DpnpRulesInplaceArrayOperator._op_map,):
6161
for operator, ufunc_name in _op_map.items():
62-
if ufunc_name in _unsupported:
62+
if ufunc_name in dpnpdecl._unsupported:
6363
continue
6464
ufunc = getattr(dpnp, ufunc_name)
6565
kernel = kernels[ufunc]

0 commit comments

Comments
 (0)