33# SPDX-License-Identifier: Apache-2.0
44
55"""
6- Provides overloads for functions included in kernel_iface .barrier that
6+ Provides overloads for functions included in kernel_api .barrier that
77generate dpcpp SPIR-V LLVM IR intrinsic function calls.
88"""
99
1010from llvmlite import ir as llvmir
11- from numba .core import cgutils , types
11+ from numba .core import types
1212from numba .core .errors import TypingError
1313from numba .extending import intrinsic , overload
1414
15- from numba_dpex .core import itanium_mangler as ext_itanium_mangler
1615from numba_dpex .core .types .kernel_api .index_space_ids import GroupType
1716from numba_dpex .experimental .target import DPEX_KERNEL_EXP_TARGET_NAME
1817from numba_dpex .kernel_api import group_barrier
1918from numba_dpex .kernel_api .memory_enums import MemoryOrder , MemoryScope
2019
2120from ._spv_atomic_inst_helper import get_memory_semantics_mask , get_scope
22- from .spv_atomic_fn_declarations import _SUPPORT_CONVERGENT
21+ from .spv_fn_declarations import (
22+ _SUPPORT_CONVERGENT ,
23+ get_or_insert_spv_group_barrier_fn ,
24+ )
2325
2426
2527def _get_memory_scope (fence_scope ):
@@ -35,7 +37,7 @@ def _intrinsic_barrier(
3537 ty_mem_scope , # pylint: disable=unused-argument
3638 ty_spirv_mem_sem_mask , # pylint: disable=unused-argument
3739):
38- # Signature of `__spirv_control_barrier ` call that is
40+ # Signature of `__spirv_ControlBarrier ` call that is
3941 # generated for group_barrier. It takes three arguments -
4042 # exec_scope, memory_scope and memory_semantics_mask.
4143 # All arguments have to be of type unsigned int32.
@@ -44,42 +46,17 @@ def _intrinsic_barrier(
4446 def _intrinsic_barrier_codegen (
4547 context , builder , sig , args # pylint: disable=unused-argument
4648 ):
47- exec_scope_arg = builder .trunc (args [0 ], llvmir .IntType (32 ))
48- mem_scope_arg = builder .trunc (args [1 ], llvmir .IntType (32 ))
49- spirv_memory_semantics_mask_arg = builder .trunc (
50- args [2 ], llvmir .IntType (32 )
51- )
52-
5349 fn_args = [
54- exec_scope_arg ,
55- mem_scope_arg ,
56- spirv_memory_semantics_mask_arg ,
50+ builder . trunc ( args [ 0 ], llvmir . IntType ( 32 )) ,
51+ builder . trunc ( args [ 1 ], llvmir . IntType ( 32 )) ,
52+ builder . trunc ( args [ 2 ], llvmir . IntType ( 32 )) ,
5753 ]
5854
59- mangled_fn_name = ext_itanium_mangler . mangle_ext (
60- "__spirv_ControlBarrier" , [ types . uint32 , types . uint32 , types . uint32 ]
55+ callinst = builder . call (
56+ get_or_insert_spv_group_barrier_fn ( builder . module ), fn_args
6157 )
6258
63- spirv_fn_arg_types = [
64- llvmir .IntType (32 ),
65- llvmir .IntType (32 ),
66- llvmir .IntType (32 ),
67- ]
68-
69- fn = cgutils .get_or_insert_function (
70- builder .module ,
71- llvmir .FunctionType (llvmir .VoidType (), spirv_fn_arg_types ),
72- mangled_fn_name ,
73- )
74-
75- if _SUPPORT_CONVERGENT :
76- fn .attributes .add ("convergent" )
77- fn .attributes .add ("nounwind" )
78- fn .calling_convention = "spir_func"
79-
80- callinst = builder .call (fn , fn_args )
81-
82- if _SUPPORT_CONVERGENT :
59+ if _SUPPORT_CONVERGENT : # pylint: disable=duplicate-code
8360 callinst .attributes .add ("convergent" )
8461 callinst .attributes .add ("nounwind" )
8562
0 commit comments