3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
5
"""
6
- Provides overloads for functions included in kernel_iface .barrier that
6
+ Provides overloads for functions included in kernel_api .barrier that
7
7
generate dpcpp SPIR-V LLVM IR intrinsic function calls.
8
8
"""
9
9
10
10
from llvmlite import ir as llvmir
11
- from numba .core import cgutils , types
11
+ from numba .core import types
12
12
from numba .core .errors import TypingError
13
13
from numba .extending import intrinsic , overload
14
14
15
- from numba_dpex .core import itanium_mangler as ext_itanium_mangler
16
15
from numba_dpex .core .types .kernel_api .index_space_ids import GroupType
17
16
from numba_dpex .experimental .target import DPEX_KERNEL_EXP_TARGET_NAME
18
17
from numba_dpex .kernel_api import group_barrier
19
18
from numba_dpex .kernel_api .memory_enums import MemoryOrder , MemoryScope
20
19
21
20
from ._spv_atomic_inst_helper import get_memory_semantics_mask , get_scope
22
- from .spv_atomic_fn_declarations import _SUPPORT_CONVERGENT
21
+ from .spv_fn_declarations import (
22
+ _SUPPORT_CONVERGENT ,
23
+ get_or_insert_spv_group_barrier_fn ,
24
+ )
23
25
24
26
25
27
def _get_memory_scope (fence_scope ):
@@ -35,7 +37,7 @@ def _intrinsic_barrier(
35
37
ty_mem_scope , # pylint: disable=unused-argument
36
38
ty_spirv_mem_sem_mask , # pylint: disable=unused-argument
37
39
):
38
- # Signature of `__spirv_control_barrier ` call that is
40
+ # Signature of `__spirv_ControlBarrier ` call that is
39
41
# generated for group_barrier. It takes three arguments -
40
42
# exec_scope, memory_scope and memory_semantics_mask.
41
43
# All arguments have to be of type unsigned int32.
@@ -44,42 +46,17 @@ def _intrinsic_barrier(
44
46
def _intrinsic_barrier_codegen (
45
47
context , builder , sig , args # pylint: disable=unused-argument
46
48
):
47
- exec_scope_arg = builder .trunc (args [0 ], llvmir .IntType (32 ))
48
- mem_scope_arg = builder .trunc (args [1 ], llvmir .IntType (32 ))
49
- spirv_memory_semantics_mask_arg = builder .trunc (
50
- args [2 ], llvmir .IntType (32 )
51
- )
52
-
53
49
fn_args = [
54
- exec_scope_arg ,
55
- mem_scope_arg ,
56
- spirv_memory_semantics_mask_arg ,
50
+ builder . trunc ( args [ 0 ], llvmir . IntType ( 32 )) ,
51
+ builder . trunc ( args [ 1 ], llvmir . IntType ( 32 )) ,
52
+ builder . trunc ( args [ 2 ], llvmir . IntType ( 32 )) ,
57
53
]
58
54
59
- mangled_fn_name = ext_itanium_mangler . mangle_ext (
60
- "__spirv_ControlBarrier" , [ types . uint32 , types . uint32 , types . uint32 ]
55
+ callinst = builder . call (
56
+ get_or_insert_spv_group_barrier_fn ( builder . module ), fn_args
61
57
)
62
58
63
- spirv_fn_arg_types = [
64
- llvmir .IntType (32 ),
65
- llvmir .IntType (32 ),
66
- llvmir .IntType (32 ),
67
- ]
68
-
69
- fn = cgutils .get_or_insert_function (
70
- builder .module ,
71
- llvmir .FunctionType (llvmir .VoidType (), spirv_fn_arg_types ),
72
- mangled_fn_name ,
73
- )
74
-
75
- if _SUPPORT_CONVERGENT :
76
- fn .attributes .add ("convergent" )
77
- fn .attributes .add ("nounwind" )
78
- fn .calling_convention = "spir_func"
79
-
80
- callinst = builder .call (fn , fn_args )
81
-
82
- if _SUPPORT_CONVERGENT :
59
+ if _SUPPORT_CONVERGENT : # pylint: disable=duplicate-code
83
60
callinst .attributes .add ("convergent" )
84
61
callinst .attributes .add ("nounwind" )
85
62
0 commit comments