Skip to content

Commit 30cc205

Browse files
adarshyogaDiptorup Deb
authored andcommitted
Moving group barrier function declarations to spv_fn_declarations
1 parent 035fceb commit 30cc205

File tree

2 files changed

+14
-37
lines changed

2 files changed

+14
-37
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
get_memory_semantics_mask,
3333
get_scope,
3434
)
35-
from .spv_atomic_fn_declarations import (
35+
from .spv_fn_declarations import (
3636
_SUPPORT_CONVERGENT,
3737
get_or_insert_atomic_load_fn,
3838
get_or_insert_spv_atomic_compare_exchange_fn,

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_group_barrier_overloads.py

Lines changed: 13 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,25 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
"""
6-
Provides overloads for functions included in kernel_iface.barrier that
6+
Provides overloads for functions included in kernel_api.barrier that
77
generate dpcpp SPIR-V LLVM IR intrinsic function calls.
88
"""
99

1010
from llvmlite import ir as llvmir
11-
from numba.core import cgutils, types
11+
from numba.core import types
1212
from numba.core.errors import TypingError
1313
from numba.extending import intrinsic, overload
1414

15-
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
1615
from numba_dpex.core.types.kernel_api.index_space_ids import GroupType
1716
from numba_dpex.experimental.target import DPEX_KERNEL_EXP_TARGET_NAME
1817
from numba_dpex.kernel_api import group_barrier
1918
from numba_dpex.kernel_api.memory_enums import MemoryOrder, MemoryScope
2019

2120
from ._spv_atomic_inst_helper import get_memory_semantics_mask, get_scope
22-
from .spv_atomic_fn_declarations import _SUPPORT_CONVERGENT
21+
from .spv_fn_declarations import (
22+
_SUPPORT_CONVERGENT,
23+
get_or_insert_spv_group_barrier_fn,
24+
)
2325

2426

2527
def _get_memory_scope(fence_scope):
@@ -35,7 +37,7 @@ def _intrinsic_barrier(
3537
ty_mem_scope, # pylint: disable=unused-argument
3638
ty_spirv_mem_sem_mask, # pylint: disable=unused-argument
3739
):
38-
# Signature of `__spirv_control_barrier` call that is
40+
# Signature of `__spirv_ControlBarrier` call that is
3941
# generated for group_barrier. It takes three arguments -
4042
# exec_scope, memory_scope and memory_semantics_mask.
4143
# All arguments have to be of type unsigned int32.
@@ -44,42 +46,17 @@ def _intrinsic_barrier(
4446
def _intrinsic_barrier_codegen(
4547
context, builder, sig, args # pylint: disable=unused-argument
4648
):
47-
exec_scope_arg = builder.trunc(args[0], llvmir.IntType(32))
48-
mem_scope_arg = builder.trunc(args[1], llvmir.IntType(32))
49-
spirv_memory_semantics_mask_arg = builder.trunc(
50-
args[2], llvmir.IntType(32)
51-
)
52-
5349
fn_args = [
54-
exec_scope_arg,
55-
mem_scope_arg,
56-
spirv_memory_semantics_mask_arg,
50+
builder.trunc(args[0], llvmir.IntType(32)),
51+
builder.trunc(args[1], llvmir.IntType(32)),
52+
builder.trunc(args[2], llvmir.IntType(32)),
5753
]
5854

59-
mangled_fn_name = ext_itanium_mangler.mangle_ext(
60-
"__spirv_ControlBarrier", [types.uint32, types.uint32, types.uint32]
55+
callinst = builder.call(
56+
get_or_insert_spv_group_barrier_fn(builder.module), fn_args
6157
)
6258

63-
spirv_fn_arg_types = [
64-
llvmir.IntType(32),
65-
llvmir.IntType(32),
66-
llvmir.IntType(32),
67-
]
68-
69-
fn = cgutils.get_or_insert_function(
70-
builder.module,
71-
llvmir.FunctionType(llvmir.VoidType(), spirv_fn_arg_types),
72-
mangled_fn_name,
73-
)
74-
75-
if _SUPPORT_CONVERGENT:
76-
fn.attributes.add("convergent")
77-
fn.attributes.add("nounwind")
78-
fn.calling_convention = "spir_func"
79-
80-
callinst = builder.call(fn, fn_args)
81-
82-
if _SUPPORT_CONVERGENT:
59+
if _SUPPORT_CONVERGENT: # pylint: disable=duplicate-code
8360
callinst.attributes.add("convergent")
8461
callinst.attributes.add("nounwind")
8562

0 commit comments

Comments
 (0)