Skip to content

Commit eb5a48b

Browse files
author
Diptorup Deb
committed
Move ocl.mathimpl and ocl.mathdecl into kernel_api_impl.spirv.math
1 parent 049db98 commit eb5a48b

File tree

7 files changed

+380
-373
lines changed

7 files changed

+380
-373
lines changed

numba_dpex/dpnp_iface/dpnp_ufunc_db.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from numba.core import types
1010

1111
from numba_dpex.core.typing import dpnpdecl
12-
13-
from ..ocl import mathimpl
12+
from numba_dpex.kernel_api_impl.spirv.math import mathimpl
1413

1514
# A global instance of dpnp ufuncs that are supported by numba-dpex
1615
_dpnp_ufunc_db = None
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
# SPDX-FileCopyrightText: 2020 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
# pylint: skip-file
6+
7+
"""Typing declarations for all ``math`` stdlib function in SPIRVTypingContext.
8+
"""
9+
import math
10+
11+
from numba.core import types
12+
from numba.core.typing.templates import (
13+
AttributeTemplate,
14+
ConcreteTemplate,
15+
Registry,
16+
signature,
17+
)
18+
19+
registry = Registry()
20+
builtin_attr = registry.register_attr
21+
infer_global = registry.register_global
22+
23+
24+
@builtin_attr
25+
class MathModuleAttribute(AttributeTemplate):
26+
key = types.Module(math)
27+
28+
def resolve_fabs(self, mod):
29+
return types.Function(MathFabsFn)
30+
31+
def resolve_exp(self, mod):
32+
return types.Function(MathExpFn)
33+
34+
def resolve_expm1(self, mod):
35+
return types.Function(MathExpm1Fn)
36+
37+
def resolve_sqrt(self, mod):
38+
return types.Function(MathSqrtFn)
39+
40+
def resolve_log(self, mod):
41+
return types.Function(MathLogFn)
42+
43+
def resolve_log1p(self, mod):
44+
return types.Function(MathLog1pFn)
45+
46+
def resolve_log10(self, mod):
47+
return types.Function(MathLog10Fn)
48+
49+
def resolve_sin(self, mod):
50+
return types.Function(MathSinFn)
51+
52+
def resolve_cos(self, mod):
53+
return types.Function(MathCosFn)
54+
55+
def resolve_tan(self, mod):
56+
return types.Function(MathTanFn)
57+
58+
def resolve_sinh(self, mod):
59+
return types.Function(MathSinhFn)
60+
61+
def resolve_cosh(self, mod):
62+
return types.Function(MathCoshFn)
63+
64+
def resolve_tanh(self, mod):
65+
return types.Function(MathTanhFn)
66+
67+
def resolve_asin(self, mod):
68+
return types.Function(MathAsinFn)
69+
70+
def resolve_acos(self, mod):
71+
return types.Function(MathAcosFn)
72+
73+
def resolve_atan(self, mod):
74+
return types.Function(MathAtanFn)
75+
76+
def resolve_atan2(self, mod):
77+
return types.Function(MathAtan2Fn)
78+
79+
def resolve_asinh(self, mod):
80+
return types.Function(MathAsinhFn)
81+
82+
def resolve_acosh(self, mod):
83+
return types.Function(MathAcoshFn)
84+
85+
def resolve_atanh(self, mod):
86+
return types.Function(MathAtanhFn)
87+
88+
def resolve_pi(self, mod):
89+
return types.float64
90+
91+
def resolve_e(self, mod):
92+
return types.float64
93+
94+
def resolve_floor(self, mod):
95+
return types.Function(MathFloorFn)
96+
97+
def resolve_ceil(self, mod):
98+
return types.Function(MathCeilFn)
99+
100+
def resolve_trunc(self, mod):
101+
return types.Function(MathTruncFn)
102+
103+
def resolve_isnan(self, mod):
104+
return types.Function(MathIsnanFn)
105+
106+
def resolve_isinf(self, mod):
107+
return types.Function(MathIsinfFn)
108+
109+
def resolve_degrees(self, mod):
110+
return types.Function(MathDegreesFn)
111+
112+
def resolve_radians(self, mod):
113+
return types.Function(MathRadiansFn)
114+
115+
def resolve_copysign(self, mod):
116+
return types.Function(MathCopysignFn)
117+
118+
def resolve_fmod(self, mod):
119+
return types.Function(MathFmodFn)
120+
121+
def resolve_pow(self, mod):
122+
return types.Function(MathPowFn)
123+
124+
def resolve_erf(self, mod):
125+
return types.Function(MathErfFn)
126+
127+
def resolve_erfc(self, mod):
128+
return types.Function(MathErfcFn)
129+
130+
def resolve_gamma(self, mod):
131+
return types.Function(MathGammaFn)
132+
133+
def resolve_lgamma(self, mod):
134+
return types.Function(MathLgammaFn)
135+
136+
137+
class UnaryMathFuncTemplate(ConcreteTemplate):
138+
cases = [
139+
signature(types.float64, types.int64),
140+
signature(types.float64, types.uint64),
141+
signature(types.float32, types.float32),
142+
signature(types.float64, types.float64),
143+
]
144+
145+
146+
class MathFabsFn(UnaryMathFuncTemplate):
147+
key = math.fabs
148+
149+
150+
class MathExpFn(UnaryMathFuncTemplate):
151+
key = math.exp
152+
153+
154+
class MathExpm1Fn(UnaryMathFuncTemplate):
155+
key = math.expm1
156+
157+
158+
class MathSqrtFn(UnaryMathFuncTemplate):
159+
key = math.sqrt
160+
161+
162+
class MathLogFn(UnaryMathFuncTemplate):
163+
key = math.log
164+
165+
166+
class MathLog1pFn(UnaryMathFuncTemplate):
167+
key = math.log1p
168+
169+
170+
class MathLog10Fn(UnaryMathFuncTemplate):
171+
key = math.log10
172+
173+
174+
class MathSinFn(UnaryMathFuncTemplate):
175+
key = math.sin
176+
177+
178+
class MathCosFn(UnaryMathFuncTemplate):
179+
key = math.cos
180+
181+
182+
class MathTanFn(UnaryMathFuncTemplate):
183+
key = math.tan
184+
185+
186+
class MathSinhFn(UnaryMathFuncTemplate):
187+
key = math.sinh
188+
189+
190+
class MathCoshFn(UnaryMathFuncTemplate):
191+
key = math.cosh
192+
193+
194+
class MathTanhFn(UnaryMathFuncTemplate):
195+
key = math.tanh
196+
197+
198+
class MathAsinFn(UnaryMathFuncTemplate):
199+
key = math.asin
200+
201+
202+
class MathAcosFn(UnaryMathFuncTemplate):
203+
key = math.acos
204+
205+
206+
class MathAtanFn(UnaryMathFuncTemplate):
207+
key = math.atan
208+
209+
210+
class MathAtan2Fn(ConcreteTemplate):
211+
key = math.atan2
212+
cases = [
213+
signature(types.float64, types.int64, types.int64),
214+
signature(types.float64, types.uint64, types.uint64),
215+
signature(types.float32, types.float32, types.float32),
216+
signature(types.float64, types.float64, types.float64),
217+
]
218+
219+
220+
class MathAsinhFn(UnaryMathFuncTemplate):
221+
key = math.asinh
222+
223+
224+
class MathAcoshFn(UnaryMathFuncTemplate):
225+
key = math.acosh
226+
227+
228+
class MathAtanhFn(UnaryMathFuncTemplate):
229+
key = math.atanh
230+
231+
232+
class MathFloorFn(UnaryMathFuncTemplate):
233+
key = math.floor
234+
235+
236+
class MathCeilFn(UnaryMathFuncTemplate):
237+
key = math.ceil
238+
239+
240+
class MathTruncFn(UnaryMathFuncTemplate):
241+
key = math.trunc
242+
243+
244+
class MathRadiansFn(UnaryMathFuncTemplate):
245+
key = math.radians
246+
247+
248+
class MathDegreesFn(UnaryMathFuncTemplate):
249+
key = math.degrees
250+
251+
252+
class MathErfFn(UnaryMathFuncTemplate):
253+
key = math.erf
254+
255+
256+
class MathErfcFn(UnaryMathFuncTemplate):
257+
key = math.erfc
258+
259+
260+
class MathGammaFn(UnaryMathFuncTemplate):
261+
key = math.gamma
262+
263+
264+
class MathLgammaFn(UnaryMathFuncTemplate):
265+
key = math.lgamma
266+
267+
268+
class BinaryMathFuncTemplate(ConcreteTemplate):
269+
cases = [
270+
signature(types.float32, types.float32, types.float32),
271+
signature(types.float64, types.float64, types.float64),
272+
]
273+
274+
275+
class MathCopysignFn(BinaryMathFuncTemplate):
276+
key = math.copysign
277+
278+
279+
class MathFmodFn(BinaryMathFuncTemplate):
280+
key = math.fmod
281+
282+
283+
class MathPowFn(ConcreteTemplate):
284+
key = math.pow
285+
cases = [
286+
signature(types.float32, types.float32, types.float32),
287+
signature(types.float64, types.float64, types.float64),
288+
signature(types.float32, types.float32, types.int32),
289+
signature(types.float64, types.float64, types.int32),
290+
]
291+
292+
293+
class MathIsnanFn(ConcreteTemplate):
294+
key = math.isnan
295+
cases = [
296+
signature(types.boolean, types.int64),
297+
signature(types.boolean, types.uint64),
298+
signature(types.boolean, types.float32),
299+
signature(types.boolean, types.float64),
300+
]
301+
302+
303+
class MathIsinfFn(ConcreteTemplate):
304+
key = math.isinf
305+
cases = [
306+
signature(types.boolean, types.int64),
307+
signature(types.boolean, types.uint64),
308+
signature(types.boolean, types.float32),
309+
signature(types.boolean, types.float64),
310+
]
311+
312+
313+
infer_global(math, types.Module(math))
314+
infer_global(math.fabs, types.Function(MathFabsFn))
315+
infer_global(math.exp, types.Function(MathExpFn))
316+
infer_global(math.expm1, types.Function(MathExpm1Fn))
317+
infer_global(math.sqrt, types.Function(MathSqrtFn))
318+
infer_global(math.log, types.Function(MathLogFn))
319+
infer_global(math.log1p, types.Function(MathLog1pFn))
320+
infer_global(math.log10, types.Function(MathLog10Fn))
321+
infer_global(math.sin, types.Function(MathSinFn))
322+
infer_global(math.cos, types.Function(MathCosFn))
323+
infer_global(math.tan, types.Function(MathTanFn))
324+
infer_global(math.sinh, types.Function(MathSinhFn))
325+
infer_global(math.cosh, types.Function(MathCoshFn))
326+
infer_global(math.tanh, types.Function(MathTanhFn))
327+
infer_global(math.asin, types.Function(MathAsinFn))
328+
infer_global(math.acos, types.Function(MathAcosFn))
329+
infer_global(math.atan, types.Function(MathAtanFn))
330+
infer_global(math.atan2, types.Function(MathAtan2Fn))
331+
infer_global(math.asinh, types.Function(MathAsinhFn))
332+
infer_global(math.acosh, types.Function(MathAcoshFn))
333+
infer_global(math.atanh, types.Function(MathAtanhFn))
334+
infer_global(math.floor, types.Function(MathFloorFn))
335+
infer_global(math.ceil, types.Function(MathCeilFn))
336+
infer_global(math.trunc, types.Function(MathTruncFn))
337+
infer_global(math.isnan, types.Function(MathIsnanFn))
338+
infer_global(math.isinf, types.Function(MathIsinfFn))
339+
infer_global(math.degrees, types.Function(MathDegreesFn))
340+
infer_global(math.radians, types.Function(MathRadiansFn))
341+
infer_global(math.copysign, types.Function(MathCopysignFn))
342+
infer_global(math.fmod, types.Function(MathFmodFn))
343+
infer_global(math.pow, types.Function(MathPowFn))
344+
infer_global(math.erf, types.Function(MathErfFn))
345+
infer_global(math.erfc, types.Function(MathErfcFn))
346+
infer_global(math.gamma, types.Function(MathGammaFn))
347+
infer_global(math.lgamma, types.Function(MathLgammaFn))

0 commit comments

Comments
 (0)