Skip to content

Commit f59d9e8

Browse files
author
Diptorup Deb
authored
Merge pull request #1297 from IntelPython/experimental/ld_str_excg_ols
Implementations for atomic load, store and exchange operations
2 parents 57190e9 + fb916a2 commit f59d9e8

File tree

4 files changed

+454
-6
lines changed

4 files changed

+454
-6
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@
2424
get_memory_semantics_mask,
2525
get_scope,
2626
)
27+
from .spv_fn_generator import (
28+
get_or_insert_atomic_load_fn,
29+
get_or_insert_spv_atomic_exchange_fn,
30+
get_or_insert_spv_atomic_store_fn,
31+
)
2732

2833

2934
def _parse_enum_or_int_literal_(literal_int) -> int:
@@ -209,6 +214,107 @@ def codegen(context, builder, sig, args):
209214
)
210215

211216

217+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_load(
    ty_context, ty_atomic_ref  # pylint: disable=unused-argument
):
    """Intrinsic implementing an atomic load through an atomic ref.

    The typing signature takes the atomic ref as its only argument and
    returns a value of the atomic ref's ``dtype``.

    Returns:
        tuple: The typing signature and the codegen function.
    """
    sig = ty_atomic_ref.dtype(ty_atomic_ref)

    def _intrinsic_load_gen(context, builder, sig, args):
        ref_ty = sig.args[0]
        load_fn = get_or_insert_atomic_load_fn(
            context, builder.module, ref_ty
        )

        # The memory-semantics mask and the scope are derived from the
        # atomic ref type and are passed to the call as int32 constants.
        semantics_mask = get_memory_semantics_mask(ref_ty.memory_order)
        scope = get_scope(ref_ty.memory_scope)

        # Pull the "ref" member out of the atomic ref value.
        ref_model = context.data_model_manager.lookup(ref_ty)
        ref_ptr = builder.extract_value(
            args[0], ref_model.get_field_position("ref")
        )

        return builder.call(
            load_fn,
            [
                ref_ptr,
                context.get_constant(types.int32, scope),
                context.get_constant(types.int32, semantics_mask),
            ],
        )

    return sig, _intrinsic_load_gen
248+
249+
250+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_store(
    ty_context, ty_atomic_ref, ty_val
):  # pylint: disable=unused-argument
    """Intrinsic implementing an atomic store through an atomic ref.

    The typing signature takes the atomic ref and the value to store, and
    returns ``void``.

    Returns:
        tuple: The typing signature and the codegen function.
    """
    sig = types.void(ty_atomic_ref, ty_val)

    def _intrinsic_store_gen(context, builder, sig, args):
        ref_ty = sig.args[0]
        store_fn = get_or_insert_spv_atomic_store_fn(
            context, builder.module, ref_ty
        )

        # Pull the "ref" member out of the atomic ref value.
        ref_model = context.data_model_manager.lookup(ref_ty)
        ref_ptr = builder.extract_value(
            args[0], ref_model.get_field_position("ref")
        )

        # Scope and memory semantics are passed as int32 constants.
        scope_const = context.get_constant(
            types.int32, get_scope(ref_ty.memory_scope)
        )
        semantics_const = context.get_constant(
            types.int32, get_memory_semantics_mask(ref_ty.memory_order)
        )

        # The store has no result, so the call value is not returned.
        builder.call(
            store_fn, [ref_ptr, scope_const, semantics_const, args[1]]
        )

    return sig, _intrinsic_store_gen
282+
283+
284+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_exchange(
    ty_context, ty_atomic_ref, ty_val  # pylint: disable=unused-argument
):
    """Intrinsic implementing an atomic exchange through an atomic ref.

    The typing signature takes the atomic ref and the value to exchange,
    and returns a value of the atomic ref's ``dtype``.

    Returns:
        tuple: The typing signature and the codegen function.
    """
    sig = ty_atomic_ref.dtype(ty_atomic_ref, ty_val)

    def _intrinsic_exchange_gen(context, builder, sig, args):
        ref_ty = sig.args[0]
        exchange_fn = get_or_insert_spv_atomic_exchange_fn(
            context, builder.module, ref_ty
        )

        # Pull the "ref" member out of the atomic ref value.
        ref_model = context.data_model_manager.lookup(ref_ty)
        ref_ptr = builder.extract_value(
            args[0], ref_model.get_field_position("ref")
        )

        # Scope and memory semantics are passed as int32 constants.
        scope_const = context.get_constant(
            types.int32, get_scope(ref_ty.memory_scope)
        )
        semantics_const = context.get_constant(
            types.int32, get_memory_semantics_mask(ref_ty.memory_order)
        )

        # The result of the call is the value returned by the intrinsic.
        return builder.call(
            exchange_fn, [ref_ptr, scope_const, semantics_const, args[1]]
        )

    return sig, _intrinsic_exchange_gen
316+
317+
212318
def _check_if_supported_ref(ref):
213319
supported = True
214320

@@ -516,3 +622,72 @@ def ol_fetch_xor_impl(atomic_ref, val):
516622
return _intrinsic_fetch_xor(atomic_ref, val)
517623

518624
return ol_fetch_xor_impl
625+
626+
627+
@overload_method(AtomicRefType, "load", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_load(atomic_ref):  # pylint: disable=unused-argument
    """SPIR-V overload of
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.load`.

    The implementation forwards to the ``_intrinsic_load`` intrinsic; the
    generated LLVM IR is intended to match what dpcpp emits for
    `atomic_ref::load`.

    """

    def ol_load_impl(atomic_ref):
        # The typing-context argument of the intrinsic is supplied by
        # numba, hence the pylint suppression below.
        # pylint: disable=no-value-for-parameter
        return _intrinsic_load(atomic_ref)

    return ol_load_impl
642+
643+
644+
@overload_method(AtomicRefType, "store", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_store(atomic_ref, val):
    """SPIR-V overload of
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.store`.

    The implementation forwards to the ``_intrinsic_store`` intrinsic; the
    generated LLVM IR is intended to match what dpcpp emits for
    `atomic_ref::store`.

    Raises:
        TypingError: When the type of the value being stored differs from
            the dtype of the AtomicRef type.
    """

    # Typing-time guard: no implicit conversion of the stored value.
    if atomic_ref.dtype != val:
        msg = (
            f"Type of value to store: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(msg)

    def ol_store_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_store(atomic_ref, val)

    return ol_store_impl
668+
669+
670+
@overload_method(AtomicRefType, "exchange", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_exchange(atomic_ref, val):
    """SPIR-V overload of
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.exchange`.

    The implementation forwards to the ``_intrinsic_exchange`` intrinsic;
    the generated LLVM IR is intended to match what dpcpp emits for
    `atomic_ref::exchange`.

    Raises:
        TypingError: When the type of the value passed to `exchange`
            differs from the dtype of the AtomicRef type.
    """

    # Typing-time guard: no implicit conversion of the exchanged value.
    if atomic_ref.dtype != val:
        msg = (
            f"Type of value to exchange: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(msg)

    def ol_exchange_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_exchange(atomic_ref, val)

    return ol_exchange_impl
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Implements a set of helper functions to generate the LLVM IR for SPIR-V
7+
functions and their use inside an LLVM module.
8+
"""
9+
10+
from llvmlite import ir as llvmir
11+
from numba.core import cgutils, types
12+
13+
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
14+
from numba_dpex.core.targets.kernel_target import CC_SPIR_FUNC
15+
16+
17+
def get_or_insert_atomic_load_fn(context, module, atomic_ref_ty):
    """
    Gets or inserts a declaration for a __spirv_AtomicLoad call into the
    specified LLVM IR module.

    Args:
        context: Target context used to map the atomic ref's dtype to an
            LLVM value type.
        module: LLVM IR module in which the declaration is looked up or
            inserted.
        atomic_ref_ty: Atomic ref type; its ``dtype`` and ``address_space``
            attributes determine the pointer argument and the return type.

    Returns:
        The LLVM function for the mangled ``__spirv_AtomicLoad`` name, with
        its calling convention set to ``CC_SPIR_FUNC``.
    """
    atomic_ref_dtype = atomic_ref_ty.dtype
    addrspace = atomic_ref_ty.address_space
    atomic_load_fn_retty = context.get_value_type(atomic_ref_dtype)

    # Build the pointer type directly in the required address space rather
    # than mutating ``addrspace`` on an already constructed type.
    ptr_type = atomic_load_fn_retty.as_pointer(addrspace=addrspace)

    # Arguments: pointer, scope (i32), memory semantics mask (i32).
    atomic_load_fn_arg_types = [
        ptr_type,
        llvmir.IntType(32),
        llvmir.IntType(32),
    ]
    mangled_fn_name = ext_itanium_mangler.mangle_ext(
        "__spirv_AtomicLoad",
        [
            types.CPointer(atomic_ref_dtype, addrspace=addrspace),
            "__spv.Scope.Flag",
            "__spv.MemorySemanticsMask.Flag",
        ],
    )

    fn = cgutils.get_or_insert_function(
        module,
        llvmir.FunctionType(atomic_load_fn_retty, atomic_load_fn_arg_types),
        mangled_fn_name,
    )
    fn.calling_convention = CC_SPIR_FUNC

    return fn
48+
49+
50+
def get_or_insert_spv_atomic_store_fn(context, module, atomic_ref_ty):
    """
    Gets or inserts a declaration for a __spirv_AtomicStore call into the
    specified LLVM IR module.

    Args:
        context: Target context used to map the atomic ref's dtype to an
            LLVM value type.
        module: LLVM IR module in which the declaration is looked up or
            inserted.
        atomic_ref_ty: Atomic ref type; its ``dtype`` and ``address_space``
            attributes determine the pointer and value argument types.

    Returns:
        The LLVM function for the mangled ``__spirv_AtomicStore`` name,
        returning ``void``, with its calling convention set to
        ``CC_SPIR_FUNC``.
    """
    atomic_ref_dtype = atomic_ref_ty.dtype
    addrspace = atomic_ref_ty.address_space
    value_type = context.get_value_type(atomic_ref_dtype)

    # Build the pointer type directly in the required address space rather
    # than mutating ``addrspace`` on an already constructed type.
    ptr_type = value_type.as_pointer(addrspace=addrspace)

    atomic_store_fn_retty = llvmir.VoidType()
    # Arguments: pointer, scope (i32), memory semantics mask (i32), value.
    atomic_store_fn_arg_types = [
        ptr_type,
        llvmir.IntType(32),
        llvmir.IntType(32),
        value_type,
    ]

    mangled_fn_name = ext_itanium_mangler.mangle_ext(
        "__spirv_AtomicStore",
        [
            types.CPointer(atomic_ref_dtype, addrspace=addrspace),
            "__spv.Scope.Flag",
            "__spv.MemorySemanticsMask.Flag",
            atomic_ref_dtype,
        ],
    )

    fn = cgutils.get_or_insert_function(
        module,
        llvmir.FunctionType(atomic_store_fn_retty, atomic_store_fn_arg_types),
        mangled_fn_name,
    )
    fn.calling_convention = CC_SPIR_FUNC

    return fn
84+
85+
86+
def get_or_insert_spv_atomic_exchange_fn(context, module, atomic_ref_ty):
    """
    Gets or inserts a declaration for a __spirv_AtomicExchange call into the
    specified LLVM IR module.

    Args:
        context: Target context used to map the atomic ref's dtype to an
            LLVM value type.
        module: LLVM IR module in which the declaration is looked up or
            inserted.
        atomic_ref_ty: Atomic ref type; its ``dtype`` and ``address_space``
            attributes determine the pointer, value and return types.

    Returns:
        The LLVM function for the mangled ``__spirv_AtomicExchange`` name,
        returning a value of the atomic ref's dtype, with its calling
        convention set to ``CC_SPIR_FUNC``.
    """
    atomic_ref_dtype = atomic_ref_ty.dtype
    addrspace = atomic_ref_ty.address_space
    # The same LLVM value type serves as return type and value argument.
    value_type = context.get_value_type(atomic_ref_dtype)

    # Build the pointer type directly in the required address space rather
    # than mutating ``addrspace`` on an already constructed type.
    ptr_type = value_type.as_pointer(addrspace=addrspace)

    atomic_exchange_fn_retty = value_type
    # Arguments: pointer, scope (i32), memory semantics mask (i32), value.
    atomic_exchange_fn_arg_types = [
        ptr_type,
        llvmir.IntType(32),
        llvmir.IntType(32),
        value_type,
    ]

    mangled_fn_name = ext_itanium_mangler.mangle_ext(
        "__spirv_AtomicExchange",
        [
            types.CPointer(atomic_ref_dtype, addrspace=addrspace),
            "__spv.Scope.Flag",
            "__spv.MemorySemanticsMask.Flag",
            atomic_ref_dtype,
        ],
    )

    fn = cgutils.get_or_insert_function(
        module,
        llvmir.FunctionType(
            atomic_exchange_fn_retty, atomic_exchange_fn_arg_types
        ),
        mangled_fn_name,
    )
    fn.calling_convention = CC_SPIR_FUNC

    return fn

0 commit comments

Comments
 (0)