Skip to content

Commit ff5b714

Browse files
committed
fixing dpnp failure caused by addition of nin, nout and ntypes
1 parent 1336f76 commit ff5b714

File tree

1 file changed

+39
-9
lines changed

1 file changed

+39
-9
lines changed

numba_dpex/dpnp_iface/dpnp_ufunc_db.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import copy
6+
import logging
67

78
import dpnp
89
import numpy as np
@@ -56,6 +57,7 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
5657
# variable is passed by value
5758
from numba.np.ufunc_db import _ufunc_db
5859

60+
failed_dpnpop_types_lst = []
5961
for ufuncop in dpnpdecl.supported_ufuncs:
6062
if ufuncop == "erf":
6163
op = getattr(dpnp, "erf")
@@ -72,20 +74,48 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
7274
"d->d": mathimpl.lower_ocl_impl[("erf", (_unary_d_d))],
7375
}
7476
else:
75-
op = getattr(dpnp, ufuncop)
77+
dpnpop = getattr(dpnp, ufuncop)
7678
npop = getattr(np, ufuncop)
77-
op.nin = npop.nin
78-
op.nout = npop.nout
79-
op.nargs = npop.nargs
80-
op.types = npop.types
81-
op.is_dpnp_ufunc = True
79+
if not hasattr(dpnpop, "nin"):
80+
dpnpop.nin = npop.nin
81+
if not hasattr(dpnpop, "nout"):
82+
dpnpop.nout = npop.nout
83+
if not hasattr(dpnpop, "nargs"):
84+
dpnpop.nargs = dpnpop.nin + dpnpop.nout
85+
86+
# Check for `types` attribute for dpnp op.
87+
# AttributeError:
88+
# If the `types` attribute is not present for dpnp op,
89+
# use the `types` attribute from corresponding numpy op.
90+
# ValueError:
91+
# Store all dpnp ops that failed when `types` attribute
92+
# is present but failure occurs when read.
93+
# Log all failing dpnp outside this loop.
94+
try:
95+
dpnpop.types
96+
except ValueError:
97+
failed_dpnpop_types_lst.append(ufuncop)
98+
except AttributeError:
99+
dpnpop.types = npop.types
100+
101+
dpnpop.is_dpnp_ufunc = True
82102
cp = copy.copy(_ufunc_db[npop])
83-
ufunc_db.update({op: cp})
84-
for key in list(ufunc_db[op].keys()):
103+
ufunc_db.update({dpnpop: cp})
104+
for key in list(ufunc_db[dpnpop].keys()):
85105
if (
86106
"FF->" in key
87107
or "DD->" in key
88108
or "F->" in key
89109
or "D->" in key
90110
):
91-
ufunc_db[op].pop(key)
111+
ufunc_db[dpnpop].pop(key)
112+
113+
if failed_dpnpop_types_lst:
114+
try:
115+
getattr(dpnp, failed_dpnpop_types_lst[0]).types
116+
except ValueError:
117+
ops = " ".join(failed_dpnpop_types_lst)
118+
logging.exception(
119+
"The types attribute for the following dpnp ops could not be "
120+
f"determined: {ops}"
121+
)

0 commit comments

Comments
 (0)