3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
5
import copy
6
+ import logging
6
7
7
8
import dpnp
8
9
import numpy as np
@@ -56,6 +57,7 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
56
57
# variable is passed by value
57
58
from numba .np .ufunc_db import _ufunc_db
58
59
60
+ failed_dpnpop_types_lst = []
59
61
for ufuncop in dpnpdecl .supported_ufuncs :
60
62
if ufuncop == "erf" :
61
63
op = getattr (dpnp , "erf" )
@@ -72,20 +74,50 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
72
74
"d->d" : mathimpl .lower_ocl_impl [("erf" , (_unary_d_d ))],
73
75
}
74
76
else :
75
- op = getattr (dpnp , ufuncop )
77
+ dpnpop = getattr (dpnp , ufuncop )
76
78
npop = getattr (np , ufuncop )
77
- op .nin = npop .nin
78
- op .nout = npop .nout
79
- op .nargs = npop .nargs
80
- op .types = npop .types
81
- op .is_dpnp_ufunc = True
79
+ if not hasattr (dpnpop , "nin" ):
80
+ dpnpop .nin = npop .nin
81
+ if not hasattr (dpnpop , "nout" ):
82
+ dpnpop .nout = npop .nout
83
+ if not hasattr (dpnpop , "nargs" ):
84
+ dpnpop .nargs = dpnpop .nin + dpnpop .nout
85
+
86
+ # Check if the dpnp operation has a `types` attribute and if an
87
+ # AttributeError gets raised then "monkey patch" the attribute from
88
+ # numpy. If the attribute lookup raised a ValueError, it indicates
89
+ # that dpnp could not be resolve the supported types for the
90
+ # operation. Dpnp will fail to resolve the `types` if no SYCL
91
+ # devices are available on the system. For such a scenario, we log
92
+ # dpnp operations for which the ValueError happened and print them
93
+ # as a user-level warning. It is done this way so that the failure
94
+ # to load the dpnpdecl registry due to the ValueError does not
95
+ # impede a user from importing numba-dpex.
96
+ try :
97
+ dpnpop .types
98
+ except ValueError :
99
+ failed_dpnpop_types_lst .append (ufuncop )
100
+ except AttributeError :
101
+ dpnpop .types = npop .types
102
+
103
+ dpnpop .is_dpnp_ufunc = True
82
104
cp = copy .copy (_ufunc_db [npop ])
83
- ufunc_db .update ({op : cp })
84
- for key in list (ufunc_db [op ].keys ()):
105
+ ufunc_db .update ({dpnpop : cp })
106
+ for key in list (ufunc_db [dpnpop ].keys ()):
85
107
if (
86
108
"FF->" in key
87
109
or "DD->" in key
88
110
or "F->" in key
89
111
or "D->" in key
90
112
):
91
- ufunc_db [op ].pop (key )
113
+ ufunc_db [dpnpop ].pop (key )
114
+
115
+ if failed_dpnpop_types_lst :
116
+ try :
117
+ getattr (dpnp , failed_dpnpop_types_lst [0 ]).types
118
+ except ValueError :
119
+ ops = " " .join (failed_dpnpop_types_lst )
120
+ logging .exception (
121
+ "The types attribute for the following dpnp ops could not be "
122
+ f"determined: { ops } "
123
+ )
0 commit comments