Skip to content

Commit 22f28f4

Browse files
committed
Move dpnp monkey patching to patch file
1 parent 9874e89 commit 22f28f4

File tree

5 files changed

+99
-101
lines changed

5 files changed

+99
-101
lines changed

numba_dpex/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from numba import __version__ as numba_version
1818

1919
from .kernel_api_impl.spirv import target as spirv_kernel_target
20-
from .numba_patches import patch_is_ufunc
20+
from .numba_patches import patch_ufuncs
2121
from .register_kernel_api_overloads import init_kernel_api_spirv_overloads
2222

2323

@@ -70,7 +70,7 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
7070
dpctl_sem_version = parse_sem_version(dpctl.__version__)
7171

7272
# Monkey patches
73-
patch_is_ufunc.patch()
73+
patch_ufuncs.patch()
7474

7575
from numba import prange # noqa E402
7676

numba_dpex/core/typing/dpnpdecl.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import logging
6-
75
import dpnp
8-
import numpy as np
96
from numba.core import types
107
from numba.core.typing.npydecl import (
118
Numpy_rules_ufunc,
@@ -25,7 +22,7 @@ def _install_operations(cls: NumpyRulesArrayOperator):
2522
for op, ufunc_name in cls._op_map.items():
2623
infer_global(op)(
2724
type(
28-
"NumpyRulesArrayOperator_" + ufunc_name,
25+
"DpnpRulesArrayOperator_" + ufunc_name,
2926
(cls,),
3027
dict(key=op),
3128
)
@@ -35,36 +32,7 @@ def _install_operations(cls: NumpyRulesArrayOperator):
3532
class DpnpRulesArrayOperator(NumpyRulesArrayOperator):
3633
@property
3734
def ufunc(self):
38-
try:
39-
dpnpop = getattr(dpnp, self._op_map[self.key])
40-
npop = getattr(np, self._op_map[self.key])
41-
if not hasattr(dpnpop, "nin"):
42-
dpnpop.nin = npop.nin
43-
if not hasattr(dpnpop, "nout"):
44-
dpnpop.nout = npop.nout
45-
if not hasattr(dpnpop, "nargs"):
46-
dpnpop.nargs = dpnpop.nin + dpnpop.nout
47-
48-
# Check if the dpnp operation has a `types` attribute and if an
49-
# AttributeError gets raised then "monkey patch" the attribute from
50-
# numpy. If the attribute lookup raised a ValueError, it indicates
51-
# that dpnp could not be resolve the supported types for the
52-
# operation. Dpnp will fail to resolve the `types` if no SYCL
53-
# devices are available on the system. For such a scenario, we print
54-
# a user-level warning.
55-
try:
56-
dpnpop.types
57-
except ValueError:
58-
logging.exception(
59-
f"The types attribute for the {dpnpop} fuction could not "
60-
"be determined."
61-
)
62-
except AttributeError:
63-
dpnpop.types = npop.types
64-
dpnpop.is_dpnp_ufunc = True
65-
return dpnpop
66-
except:
67-
pass
35+
return getattr(dpnp, self._op_map[self.key])
6836

6937
@classmethod
7038
def install_operations(cls):

numba_dpex/dpnp_iface/dpnp_ufunc_db.py

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import copy
6-
import logging
76

87
import dpnp
98
import numpy as np
@@ -57,15 +56,9 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
5756
# variable is passed by value
5857
from numba.np.ufunc_db import _ufunc_db
5958

60-
failed_dpnpop_types_lst = []
6159
for ufuncop in dpnpdecl.supported_ufuncs:
6260
if ufuncop == "erf":
6361
op = getattr(dpnp, "erf")
64-
op.nin = 1
65-
op.nout = 1
66-
op.nargs = 2
67-
op.types = ["f->f", "d->d"]
68-
op.is_dpnp_ufunc = True
6962

7063
_unary_d_d = types.float64(types.float64)
7164
_unary_f_f = types.float32(types.float32)
@@ -76,31 +69,7 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
7669
else:
7770
dpnpop = getattr(dpnp, ufuncop)
7871
npop = getattr(np, ufuncop)
79-
if not hasattr(dpnpop, "nin"):
80-
dpnpop.nin = npop.nin
81-
if not hasattr(dpnpop, "nout"):
82-
dpnpop.nout = npop.nout
83-
if not hasattr(dpnpop, "nargs"):
84-
dpnpop.nargs = dpnpop.nin + dpnpop.nout
85-
86-
# Check if the dpnp operation has a `types` attribute and if an
87-
# AttributeError gets raised then "monkey patch" the attribute from
88-
# numpy. If the attribute lookup raised a ValueError, it indicates
89-
# that dpnp could not be resolve the supported types for the
90-
# operation. Dpnp will fail to resolve the `types` if no SYCL
91-
# devices are available on the system. For such a scenario, we log
92-
# dpnp operations for which the ValueError happened and print them
93-
# as a user-level warning. It is done this way so that the failure
94-
# to load the dpnpdecl registry due to the ValueError does not
95-
# impede a user from importing numba-dpex.
96-
try:
97-
dpnpop.types
98-
except ValueError:
99-
failed_dpnpop_types_lst.append(ufuncop)
100-
except AttributeError:
101-
dpnpop.types = npop.types
102-
103-
dpnpop.is_dpnp_ufunc = True
72+
10473
cp = copy.copy(_ufunc_db[npop])
10574
ufunc_db.update({dpnpop: cp})
10675
for key in list(ufunc_db[dpnpop].keys()):
@@ -111,13 +80,3 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
11180
or "D->" in key
11281
):
11382
ufunc_db[dpnpop].pop(key)
114-
115-
if failed_dpnpop_types_lst:
116-
try:
117-
getattr(dpnp, failed_dpnpop_types_lst[0]).types
118-
except ValueError:
119-
ops = " ".join(failed_dpnpop_types_lst)
120-
logging.exception(
121-
"The types attribute for the following dpnp ops could not be "
122-
f"determined: {ops}"
123-
)

numba_dpex/numba_patches/patch_is_ufunc.py

Lines changed: 0 additions & 23 deletions
This file was deleted.
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# SPDX-FileCopyrightText: 2020 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import logging
6+
7+
import dpnp
8+
import numpy as np
9+
10+
from numba_dpex.core.typing import dpnpdecl
11+
12+
13+
def patch():
14+
patch_is_ufunc()
15+
patch_ufuncs()
16+
17+
18+
def patch_is_ufunc():
19+
"""Patches the numba.np.ufunc.array_exprs._is_ufunc function to make it
20+
possible to support dpnp universal functions (ufuncs).
21+
22+
The extra condition is the check for the "is_dpnp_ufunc" attribute to
23+
identify a non-NumPy ufunc.
24+
"""
25+
import numpy
26+
from numba.np.ufunc.dufunc import DUFunc
27+
28+
def _is_ufunc(func):
29+
return isinstance(func, (numpy.ufunc, DUFunc)) or hasattr(
30+
func, "is_dpnp_ufunc"
31+
)
32+
33+
from numba.np.ufunc import array_exprs
34+
35+
array_exprs._is_ufunc = _is_ufunc
36+
37+
38+
def patch_ufuncs():
39+
"""Patches dpnp user functions to make them compatible with numpy, so we
40+
can reuse numba's implementation.
41+
42+
It adds "nin", "nout", "nargs" and "is_dpnp_ufunc" attributes to ufuncs.
43+
"""
44+
failed_dpnpop_types_lst = []
45+
46+
op = getattr(dpnp, "erf")
47+
op.nin = 1
48+
op.nout = 1
49+
op.nargs = 2
50+
op.types = ["f->f", "d->d"]
51+
op.is_dpnp_ufunc = True
52+
53+
for ufuncop in dpnpdecl.supported_ufuncs:
54+
if ufuncop == "erf":
55+
continue
56+
57+
dpnpop = getattr(dpnp, ufuncop)
58+
npop = getattr(np, ufuncop)
59+
60+
if not hasattr(dpnpop, "nin"):
61+
dpnpop.nin = npop.nin
62+
if not hasattr(dpnpop, "nout"):
63+
dpnpop.nout = npop.nout
64+
if not hasattr(dpnpop, "nargs"):
65+
dpnpop.nargs = dpnpop.nin + dpnpop.nout
66+
67+
# Check if the dpnp operation has a `types` attribute and if an
68+
# AttributeError gets raised then "monkey patch" the attribute from
69+
# numpy. If the attribute lookup raised a ValueError, it indicates
70+
# that dpnp could not be resolve the supported types for the
71+
# operation. Dpnp will fail to resolve the `types` if no SYCL
72+
# devices are available on the system. For such a scenario, we log
73+
# dpnp operations for which the ValueError happened and print them
74+
# as a user-level warning. It is done this way so that the failure
75+
# to load the dpnpdecl registry due to the ValueError does not
76+
# impede a user from importing numba-dpex.
77+
try:
78+
dpnpop.types
79+
except ValueError:
80+
failed_dpnpop_types_lst.append(ufuncop)
81+
except AttributeError:
82+
dpnpop.types = npop.types
83+
84+
dpnpop.is_dpnp_ufunc = True
85+
86+
if len(failed_dpnpop_types_lst) > 0:
87+
try:
88+
getattr(dpnp, failed_dpnpop_types_lst[0]).types
89+
except ValueError:
90+
ops = " ".join(failed_dpnpop_types_lst)
91+
logging.exception(
92+
"The types attribute for the following dpnp ops could not be "
93+
f"determined: {ops}"
94+
)

0 commit comments

Comments
 (0)