Skip to content

Commit 582fca9

Browse files
author
Diptorup Deb
authored
Merge pull request #1215 from IntelPython/atomic_fence
Implementing atomic_fence in dpex.kernel
2 parents eaf4dd6 + 30cc205 commit 582fca9

File tree

8 files changed

+214
-37
lines changed

8 files changed

+214
-37
lines changed

numba_dpex/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from numba_dpex.kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher
1414

1515
from ._kernel_dpcpp_spirv_overloads import (
16+
_atomic_fence_overloads,
1617
_atomic_ref_overloads,
1718
_group_barrier_overloads,
1819
_index_space_id_overloads,
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Provides overloads for functions included in kernel_api.atomic_fence
7+
that generate dpcpp SPIR-V LLVM IR intrinsic function calls.
8+
"""
9+
from llvmlite import ir as llvmir
10+
from numba.core import types
11+
from numba.extending import intrinsic, overload
12+
13+
from numba_dpex.kernel_api import atomic_fence
14+
15+
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
16+
from ._spv_atomic_inst_helper import get_memory_semantics_mask, get_scope
17+
from .spv_fn_declarations import (
18+
_SUPPORT_CONVERGENT,
19+
get_or_insert_spv_atomic_fence_fn,
20+
)
21+
22+
23+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
24+
def _intrinsic_atomic_fence(
25+
ty_context, ty_spirv_mem_sem_mask, ty_spirv_scope
26+
): # pylint: disable=unused-argument
27+
28+
# Signature of `__spirv_MemoryBarrier` call that is
29+
# generated for atomic_fence. It takes two arguments -
30+
# scope and memory_semantics_mask.
31+
# All arguments have to be of type unsigned int32.
32+
sig = types.void(types.uint32, types.uint32)
33+
34+
def _intrinsic_atomic_fence_gen(
35+
context, builder, sig, args
36+
): # pylint: disable=unused-argument
37+
callinst = builder.call(
38+
get_or_insert_spv_atomic_fence_fn(builder.module),
39+
[
40+
builder.trunc(args[1], llvmir.IntType(32)), # scope
41+
builder.trunc(args[0], llvmir.IntType(32)), # semantics mask
42+
],
43+
)
44+
45+
if _SUPPORT_CONVERGENT: # pylint: disable=duplicate-code
46+
callinst.attributes.add("convergent")
47+
callinst.attributes.add("nounwind")
48+
49+
return (
50+
sig,
51+
_intrinsic_atomic_fence_gen,
52+
)
53+
54+
55+
@overload(
56+
atomic_fence,
57+
prefer_literal=True,
58+
target=DPEX_KERNEL_EXP_TARGET_NAME,
59+
)
60+
def ol_atomic_fence(memory_order, memory_scope):
61+
"""SPIR-V overload for
62+
:meth:`numba_dpex.kernel_api.atomic_fence`.
63+
64+
Generates the same LLVM IR instruction as DPC++ for the SYCL
65+
`atomic_fence` function.
66+
"""
67+
spirv_memory_semantics_mask = get_memory_semantics_mask(
68+
memory_order.literal_value
69+
)
70+
spirv_scope = get_scope(memory_scope.literal_value)
71+
72+
def ol_atomic_fence_impl(
73+
memory_order, memory_scope
74+
): # pylint: disable=unused-argument
75+
# pylint: disable=no-value-for-parameter
76+
return _intrinsic_atomic_fence(spirv_memory_semantics_mask, spirv_scope)
77+
78+
return ol_atomic_fence_impl

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
get_memory_semantics_mask,
3333
get_scope,
3434
)
35-
from .spv_atomic_fn_declarations import (
35+
from .spv_fn_declarations import (
3636
_SUPPORT_CONVERGENT,
3737
get_or_insert_atomic_load_fn,
3838
get_or_insert_spv_atomic_compare_exchange_fn,

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_group_barrier_overloads.py

Lines changed: 13 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,25 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
"""
6-
Provides overloads for functions included in kernel_iface.barrier that
6+
Provides overloads for functions included in kernel_api.barrier that
77
generate dpcpp SPIR-V LLVM IR intrinsic function calls.
88
"""
99

1010
from llvmlite import ir as llvmir
11-
from numba.core import cgutils, types
11+
from numba.core import types
1212
from numba.core.errors import TypingError
1313
from numba.extending import intrinsic, overload
1414

15-
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
1615
from numba_dpex.core.types.kernel_api.index_space_ids import GroupType
1716
from numba_dpex.experimental.target import DPEX_KERNEL_EXP_TARGET_NAME
1817
from numba_dpex.kernel_api import group_barrier
1918
from numba_dpex.kernel_api.memory_enums import MemoryOrder, MemoryScope
2019

2120
from ._spv_atomic_inst_helper import get_memory_semantics_mask, get_scope
22-
from .spv_atomic_fn_declarations import _SUPPORT_CONVERGENT
21+
from .spv_fn_declarations import (
22+
_SUPPORT_CONVERGENT,
23+
get_or_insert_spv_group_barrier_fn,
24+
)
2325

2426

2527
def _get_memory_scope(fence_scope):
@@ -35,7 +37,7 @@ def _intrinsic_barrier(
3537
ty_mem_scope, # pylint: disable=unused-argument
3638
ty_spirv_mem_sem_mask, # pylint: disable=unused-argument
3739
):
38-
# Signature of `__spirv_control_barrier` call that is
40+
# Signature of `__spirv_ControlBarrier` call that is
3941
# generated for group_barrier. It takes three arguments -
4042
# exec_scope, memory_scope and memory_semantics_mask.
4143
# All arguments have to be of type unsigned int32.
@@ -44,42 +46,17 @@ def _intrinsic_barrier(
4446
def _intrinsic_barrier_codegen(
4547
context, builder, sig, args # pylint: disable=unused-argument
4648
):
47-
exec_scope_arg = builder.trunc(args[0], llvmir.IntType(32))
48-
mem_scope_arg = builder.trunc(args[1], llvmir.IntType(32))
49-
spirv_memory_semantics_mask_arg = builder.trunc(
50-
args[2], llvmir.IntType(32)
51-
)
52-
5349
fn_args = [
54-
exec_scope_arg,
55-
mem_scope_arg,
56-
spirv_memory_semantics_mask_arg,
50+
builder.trunc(args[0], llvmir.IntType(32)),
51+
builder.trunc(args[1], llvmir.IntType(32)),
52+
builder.trunc(args[2], llvmir.IntType(32)),
5753
]
5854

59-
mangled_fn_name = ext_itanium_mangler.mangle_ext(
60-
"__spirv_ControlBarrier", [types.uint32, types.uint32, types.uint32]
55+
callinst = builder.call(
56+
get_or_insert_spv_group_barrier_fn(builder.module), fn_args
6157
)
6258

63-
spirv_fn_arg_types = [
64-
llvmir.IntType(32),
65-
llvmir.IntType(32),
66-
llvmir.IntType(32),
67-
]
68-
69-
fn = cgutils.get_or_insert_function(
70-
builder.module,
71-
llvmir.FunctionType(llvmir.VoidType(), spirv_fn_arg_types),
72-
mangled_fn_name,
73-
)
74-
75-
if _SUPPORT_CONVERGENT:
76-
fn.attributes.add("convergent")
77-
fn.attributes.add("nounwind")
78-
fn.calling_convention = "spir_func"
79-
80-
callinst = builder.call(fn, fn_args)
81-
82-
if _SUPPORT_CONVERGENT:
59+
if _SUPPORT_CONVERGENT: # pylint: disable=duplicate-code
8360
callinst.attributes.add("convergent")
8461
callinst.attributes.add("nounwind")
8562

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/spv_atomic_fn_declarations.py renamed to numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/spv_fn_declarations.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,56 @@ def get_or_insert_spv_atomic_compare_exchange_fn(
226226
fn.attributes.add("nounwind")
227227

228228
return fn
229+
230+
231+
def get_or_insert_spv_group_barrier_fn(module):
232+
"""
233+
Gets or inserts a declaration for a __spirv_ControlBarrier call into the
234+
specified LLVM IR module.
235+
"""
236+
mangled_fn_name = ext_itanium_mangler.mangle_ext(
237+
"__spirv_ControlBarrier", [types.uint32, types.uint32, types.uint32]
238+
)
239+
240+
spirv_fn_arg_types = [
241+
llvmir.IntType(32),
242+
llvmir.IntType(32),
243+
llvmir.IntType(32),
244+
]
245+
246+
fn = cgutils.get_or_insert_function(
247+
module,
248+
llvmir.FunctionType(llvmir.VoidType(), spirv_fn_arg_types),
249+
mangled_fn_name,
250+
)
251+
fn.calling_convention = CC_SPIR_FUNC
252+
253+
if _SUPPORT_CONVERGENT:
254+
fn.attributes.add("convergent")
255+
fn.attributes.add("nounwind")
256+
257+
return fn
258+
259+
260+
def get_or_insert_spv_atomic_fence_fn(module):
261+
"""
262+
Gets or inserts a declaration for a __spirv_MemoryBarrier call into the
263+
specified LLVM IR module.
264+
"""
265+
mangled_fn_name = ext_itanium_mangler.mangle_ext(
266+
"__spirv_MemoryBarrier", [types.uint32, types.uint32]
267+
)
268+
269+
fn = cgutils.get_or_insert_function(
270+
module,
271+
llvmir.FunctionType(
272+
llvmir.VoidType(), [llvmir.IntType(32), llvmir.IntType(32)]
273+
),
274+
mangled_fn_name,
275+
)
276+
fn.calling_convention = CC_SPIR_FUNC
277+
if _SUPPORT_CONVERGENT:
278+
fn.attributes.add("convergent")
279+
fn.attributes.add("nounwind")
280+
281+
return fn

numba_dpex/kernel_api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
numba_dpex.
1010
"""
1111

12+
from .atomic_fence import atomic_fence
1213
from .atomic_ref import AtomicRef
1314
from .barrier import group_barrier
1415
from .index_space_ids import Group, Item, NdItem
@@ -18,6 +19,7 @@
1819

1920
__all__ = [
2021
"AddressSpace",
22+
"atomic_fence",
2123
"AtomicRef",
2224
"MemoryOrder",
2325
"MemoryScope",

numba_dpex/kernel_api/atomic_fence.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Python functions that simulate SYCL's atomic_fence primitives.
6+
"""
7+
8+
9+
def atomic_fence(memory_order, memory_scope): # pylint: disable=unused-argument
10+
"""The function for performing memory fence across all work-items.
11+
Modeled after ``sycl::atomic_fence`` function.
12+
It provides control over re-ordering of memory load
13+
and store operations. The ``atomic_fence`` function acts as a
14+
fence across all work-items and devices specified by a
15+
memory_scope argument.
16+
17+
Args:
18+
memory_order: The memory synchronization order.
19+
20+
memory_scope: The set of work-items and devices to which
21+
the memory ordering constraints apply.
22+
23+
"""
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import dpnp
2+
3+
import numba_dpex as dpex
4+
import numba_dpex.experimental as dpex_exp
5+
from numba_dpex.kernel_api import (
6+
AtomicRef,
7+
Item,
8+
MemoryOrder,
9+
MemoryScope,
10+
atomic_fence,
11+
)
12+
from numba_dpex.tests._helper import skip_windows
13+
14+
15+
# TODO: https://github.com/IntelPython/numba-dpex/issues/1308
16+
@skip_windows
17+
def test_atomic_fence():
18+
"""A test for atomic_fence function."""
19+
20+
@dpex_exp.kernel
21+
def _kernel(item: Item, a, b):
22+
i = item.get_id(0)
23+
24+
bref = AtomicRef(b)
25+
26+
if i == 1:
27+
a[i] += 1
28+
atomic_fence(MemoryOrder.RELEASE, MemoryScope.DEVICE)
29+
bref.store(1)
30+
elif i == 0:
31+
while not bref.load():
32+
continue
33+
atomic_fence(MemoryOrder.ACQUIRE, MemoryScope.DEVICE)
34+
for idx in range(1, a.size):
35+
a[0] += a[idx]
36+
37+
N = 2
38+
a = dpnp.ones(N, dtype=dpnp.int64)
39+
b = dpnp.zeros(1, dtype=dpnp.int64)
40+
41+
dpex_exp.call_kernel(_kernel, dpex.Range(N), a, b)
42+
43+
assert a[0] == N + 1

0 commit comments

Comments
 (0)