Skip to content

Commit 049db98

Browse files
author
Diptorup Deb
committed
Move ocl._declare_function into core.utils.cgutils.extra
1 parent a455e46 commit 049db98

File tree

3 files changed

+49
-56
lines changed

3 files changed

+49
-56
lines changed

numba_dpex/core/utils/cgutils_extra.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from llvmlite import ir as llvmir
66
from numba.core import cgutils, types
77

8+
from numba_dpex.core.utils.itanium_mangler import mangle_c
9+
810

911
class LLVMTypes:
1012
"""
@@ -21,6 +23,44 @@ class LLVMTypes:
2123
void_t = llvmir.VoidType()
2224

2325

26+
def declare_function(context, builder, name, sig, cargs, mangler=mangle_c):
27+
"""Insert declaration for a opencl builtin function.
28+
Uses the Itanium mangler.
29+
30+
Args
31+
----
32+
context: target context
33+
34+
builder: llvm builder
35+
36+
name: str
37+
symbol name
38+
39+
sig: signature
40+
function signature of the symbol being declared
41+
42+
cargs: sequence of str
43+
C type names for the arguments
44+
45+
mangler: a mangler function
46+
function to use to mangle the symbol
47+
48+
"""
49+
mod = builder.module
50+
if sig.return_type == types.void:
51+
llretty = llvmir.VoidType()
52+
else:
53+
llretty = context.get_value_type(sig.return_type)
54+
llargs = [context.get_value_type(t) for t in sig.args]
55+
fnty = llvmir.FunctionType(llretty, llargs)
56+
mangled = mangler(name, cargs)
57+
fn = cgutils.get_or_insert_function(mod, fnty, mangled)
58+
from numba_dpex import spirv_kernel_target
59+
60+
fn.calling_convention = spirv_kernel_target.CC_SPIR_FUNC
61+
return fn
62+
63+
2464
def get_llvm_type(context, type):
2565
"""Returns the LLVM Value corresponding to a Numba type.
2666

numba_dpex/ocl/_declare_function.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

numba_dpex/ocl/mathimpl.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,11 @@
99
from numba.core import types
1010
from numba.core.imputils import Registry
1111

12-
from numba_dpex.core.utils.itanium_mangler import mangle
13-
14-
from ._declare_function import _declare_function
12+
from numba_dpex.core.utils import cgutils_extra, itanium_mangler
1513

1614
registry = Registry()
1715
lower = registry.lower
1816

19-
# -----------------------------------------------------------------------------
20-
2117
_unary_b_f = types.int32(types.float32)
2218
_unary_b_d = types.int32(types.float64)
2319
_unary_f_f = types.float32(types.float32)
@@ -88,16 +84,21 @@
8884

8985

9086
# some functions may be named differently by the underlying math
91-
# library as oposed to the Python name.
87+
# library as opposed to the Python name.
9288
_lib_counterpart = {"gamma": "tgamma"}
9389

9490

9591
def _mk_fn_decl(name, decl_sig):
9692
sym = _lib_counterpart.get(name, name)
9793

9894
def core(context, builder, sig, args):
99-
fn = _declare_function(
100-
context, builder, sym, decl_sig, decl_sig.args, mangler=mangle
95+
fn = cgutils_extra.declare_function(
96+
context,
97+
builder,
98+
sym,
99+
decl_sig,
100+
decl_sig.args,
101+
mangler=itanium_mangler.mangle,
101102
)
102103
res = builder.call(fn, args)
103104
return context.cast(builder, res, decl_sig.return_type, sig.return_type)

0 commit comments

Comments
 (0)