Skip to content

Commit 4708ac7

Browse files
author
Diptorup Deb
authored
Merge pull request #1434 from IntelPython/fix/dpnp_nin_issue
Fix for dpex failure caused by addition of nin, nout and types
2 parents 1336f76 + c793313 commit 4708ac7

File tree

2 files changed

+69
-16
lines changed

2 files changed

+69
-16
lines changed

numba_dpex/core/typing/dpnpdecl.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import logging
6+
57
import dpnp
68
import numpy as np
79
from numba.core import types
@@ -34,14 +36,33 @@ class DpnpRulesArrayOperator(NumpyRulesArrayOperator):
3436
@property
3537
def ufunc(self):
3638
try:
37-
op = getattr(dpnp, self._op_map[self.key])
39+
dpnpop = getattr(dpnp, self._op_map[self.key])
3840
npop = getattr(np, self._op_map[self.key])
39-
op.nin = npop.nin
40-
op.nout = npop.nout
41-
op.nargs = npop.nargs
42-
op.types = npop.types
43-
op.is_dpnp_ufunc = True
44-
return op
41+
if not hasattr(dpnpop, "nin"):
42+
dpnpop.nin = npop.nin
43+
if not hasattr(dpnpop, "nout"):
44+
dpnpop.nout = npop.nout
45+
if not hasattr(dpnpop, "nargs"):
46+
dpnpop.nargs = dpnpop.nin + dpnpop.nout
47+
48+
# Check if the dpnp operation has a `types` attribute and if an
49+
# AttributeError gets raised then "monkey patch" the attribute from
50+
# numpy. If the attribute lookup raised a ValueError, it indicates
51+
# that dpnp could not be resolve the supported types for the
52+
# operation. Dpnp will fail to resolve the `types` if no SYCL
53+
# devices are available on the system. For such a scenario, we print
54+
# a user-level warning.
55+
try:
56+
dpnpop.types
57+
except ValueError:
58+
logging.exception(
59+
f"The types attribute for the {dpnpop} fuction could not "
60+
"be determined."
61+
)
62+
except AttributeError:
63+
dpnpop.types = npop.types
64+
dpnpop.is_dpnp_ufunc = True
65+
return dpnpop
4566
except:
4667
pass
4768

numba_dpex/dpnp_iface/dpnp_ufunc_db.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import copy
6+
import logging
67

78
import dpnp
89
import numpy as np
@@ -56,6 +57,7 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
5657
# variable is passed by value
5758
from numba.np.ufunc_db import _ufunc_db
5859

60+
failed_dpnpop_types_lst = []
5961
for ufuncop in dpnpdecl.supported_ufuncs:
6062
if ufuncop == "erf":
6163
op = getattr(dpnp, "erf")
@@ -72,20 +74,50 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
7274
"d->d": mathimpl.lower_ocl_impl[("erf", (_unary_d_d))],
7375
}
7476
else:
75-
op = getattr(dpnp, ufuncop)
77+
dpnpop = getattr(dpnp, ufuncop)
7678
npop = getattr(np, ufuncop)
77-
op.nin = npop.nin
78-
op.nout = npop.nout
79-
op.nargs = npop.nargs
80-
op.types = npop.types
81-
op.is_dpnp_ufunc = True
79+
if not hasattr(dpnpop, "nin"):
80+
dpnpop.nin = npop.nin
81+
if not hasattr(dpnpop, "nout"):
82+
dpnpop.nout = npop.nout
83+
if not hasattr(dpnpop, "nargs"):
84+
dpnpop.nargs = dpnpop.nin + dpnpop.nout
85+
86+
# Check if the dpnp operation has a `types` attribute and if an
87+
# AttributeError gets raised then "monkey patch" the attribute from
88+
# numpy. If the attribute lookup raised a ValueError, it indicates
89+
# that dpnp could not be resolve the supported types for the
90+
# operation. Dpnp will fail to resolve the `types` if no SYCL
91+
# devices are available on the system. For such a scenario, we log
92+
# dpnp operations for which the ValueError happened and print them
93+
# as a user-level warning. It is done this way so that the failure
94+
# to load the dpnpdecl registry due to the ValueError does not
95+
# impede a user from importing numba-dpex.
96+
try:
97+
dpnpop.types
98+
except ValueError:
99+
failed_dpnpop_types_lst.append(ufuncop)
100+
except AttributeError:
101+
dpnpop.types = npop.types
102+
103+
dpnpop.is_dpnp_ufunc = True
82104
cp = copy.copy(_ufunc_db[npop])
83-
ufunc_db.update({op: cp})
84-
for key in list(ufunc_db[op].keys()):
105+
ufunc_db.update({dpnpop: cp})
106+
for key in list(ufunc_db[dpnpop].keys()):
85107
if (
86108
"FF->" in key
87109
or "DD->" in key
88110
or "F->" in key
89111
or "D->" in key
90112
):
91-
ufunc_db[op].pop(key)
113+
ufunc_db[dpnpop].pop(key)
114+
115+
if failed_dpnpop_types_lst:
116+
try:
117+
getattr(dpnp, failed_dpnpop_types_lst[0]).types
118+
except ValueError:
119+
ops = " ".join(failed_dpnpop_types_lst)
120+
logging.exception(
121+
"The types attribute for the following dpnp ops could not be "
122+
f"determined: {ops}"
123+
)

0 commit comments

Comments
 (0)