Skip to content

Commit 5374eb4

Browse files
author
Diptorup Deb
authored
Merge pull request #1312 from IntelPython/experimental/cmp_exchg_ols
Atomic compare_exchange implementation
2 parents 6e9a574 + 69d0bfe commit 5374eb4

File tree

4 files changed

+271
-3
lines changed

4 files changed

+271
-3
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535
from .spv_fn_generator import (
3636
get_or_insert_atomic_load_fn,
37+
get_or_insert_spv_atomic_compare_exchange_fn,
3738
get_or_insert_spv_atomic_exchange_fn,
3839
get_or_insert_spv_atomic_store_fn,
3940
)
@@ -323,6 +324,108 @@ def _intrinsic_exchange_gen(context, builder, sig, args):
323324
return sig, _intrinsic_exchange_gen
324325

325326

327+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
328+
def _intrinsic_compare_exchange(
329+
ty_context, # pylint: disable=unused-argument
330+
ty_atomic_ref,
331+
ty_expected_ref,
332+
ty_desired,
333+
ty_expected_idx,
334+
):
335+
sig = types.boolean(
336+
ty_atomic_ref, ty_expected_ref, ty_desired, ty_expected_idx
337+
)
338+
339+
def _intrinsic_compare_exchange_gen(context, builder, sig, args):
340+
# get pointer to expected[expected_idx]
341+
data_attr = builder.extract_value(
342+
args[1],
343+
context.data_model_manager.lookup(sig.args[1]).get_field_position(
344+
"data"
345+
),
346+
)
347+
with builder.goto_entry_block():
348+
ptr_to_data_attr = builder.alloca(data_attr.type)
349+
builder.store(data_attr, ptr_to_data_attr)
350+
expected_ref_ptr = builder.gep(
351+
builder.load(ptr_to_data_attr), [args[3]]
352+
)
353+
354+
expected_arg = builder.load(expected_ref_ptr)
355+
desired_arg = args[2]
356+
atomic_ref_ptr = builder.extract_value(
357+
args[0],
358+
context.data_model_manager.lookup(sig.args[0]).get_field_position(
359+
"ref"
360+
),
361+
)
362+
# add conditional bitcast for atomic_ref pointer,
363+
# expected[expected_idx], and desired
364+
if sig.args[0].dtype == types.float32:
365+
atomic_ref_ptr = builder.bitcast(
366+
atomic_ref_ptr,
367+
llvmir.PointerType(
368+
llvmir.IntType(32), addrspace=sig.args[0].address_space
369+
),
370+
)
371+
expected_arg = builder.bitcast(expected_arg, llvmir.IntType(32))
372+
desired_arg = builder.bitcast(desired_arg, llvmir.IntType(32))
373+
elif sig.args[0].dtype == types.float64:
374+
atomic_ref_ptr = builder.bitcast(
375+
atomic_ref_ptr,
376+
llvmir.PointerType(
377+
llvmir.IntType(64), addrspace=sig.args[0].address_space
378+
),
379+
)
380+
expected_arg = builder.bitcast(expected_arg, llvmir.IntType(64))
381+
desired_arg = builder.bitcast(desired_arg, llvmir.IntType(64))
382+
383+
atomic_cmpexchg_fn_args = [
384+
atomic_ref_ptr,
385+
context.get_constant(
386+
types.int32, get_scope(sig.args[0].memory_scope)
387+
),
388+
context.get_constant(
389+
types.int32,
390+
get_memory_semantics_mask(sig.args[0].memory_order),
391+
),
392+
context.get_constant(
393+
types.int32,
394+
get_memory_semantics_mask(sig.args[0].memory_order),
395+
),
396+
desired_arg,
397+
expected_arg,
398+
]
399+
400+
ret_val = builder.call(
401+
get_or_insert_spv_atomic_compare_exchange_fn(
402+
context, builder.module, sig.args[0]
403+
),
404+
atomic_cmpexchg_fn_args,
405+
)
406+
407+
# compare_exchange returns the old value stored in AtomicRef object.
408+
# If the return value is same as expected, then compare_exchange
409+
# succeeded in replacing AtomicRef object with desired.
410+
# If the return value is not same as expected, then store return
411+
# value in expected.
412+
# In either case, return result of cmp instruction.
413+
is_cmp_exchg_success = builder.icmp_signed("==", ret_val, expected_arg)
414+
415+
with builder.if_else(is_cmp_exchg_success) as (then, otherwise):
416+
with then:
417+
pass
418+
with otherwise:
419+
if sig.args[0].dtype == types.float32:
420+
ret_val = builder.bitcast(ret_val, llvmir.FloatType())
421+
elif sig.args[0].dtype == types.float64:
422+
ret_val = builder.bitcast(ret_val, llvmir.DoubleType())
423+
builder.store(ret_val, expected_ref_ptr)
424+
return is_cmp_exchg_success
425+
426+
return sig, _intrinsic_compare_exchange_gen
427+
428+
326429
def _check_if_supported_ref(ref):
327430
supported = True
328431

@@ -689,3 +792,50 @@ def ol_exchange_impl(atomic_ref, val):
689792
return _intrinsic_exchange(atomic_ref, val)
690793

691794
return ol_exchange_impl
795+
796+
797+
@overload_method(
798+
AtomicRefType,
799+
"compare_exchange",
800+
target=DPEX_KERNEL_EXP_TARGET_NAME,
801+
)
802+
def ol_compare_exchange(
803+
atomic_ref,
804+
expected_ref,
805+
desired,
806+
expected_idx=0, # pylint: disable=unused-argument
807+
):
808+
"""SPIR-V overload for
809+
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.compare_exchange`.
810+
811+
Generates the same LLVM IR instruction as dpcpp for the
812+
`atomic_ref::compare_exchange_strong` function.
813+
814+
Raises:
815+
TypingError: When the dtype of the value passed to `compare_exchange`
816+
does not match the dtype of the AtomicRef type.
817+
"""
818+
819+
_check_if_supported_ref(expected_ref)
820+
821+
if atomic_ref.dtype != expected_ref.dtype:
822+
raise errors.TypingError(
823+
f"Type of value to compare_exchange: {expected_ref} does not match the "
824+
f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
825+
)
826+
827+
if atomic_ref.dtype != desired:
828+
raise errors.TypingError(
829+
f"Type of value to compare_exchange: {desired} does not match the "
830+
f"type of the reference: {atomic_ref.dtype} stored in the atomic ref."
831+
)
832+
833+
def ol_compare_exchange_impl(
834+
atomic_ref, expected_ref, desired, expected_idx=0
835+
):
836+
# pylint: disable=no-value-for-parameter
837+
return _intrinsic_compare_exchange(
838+
atomic_ref, expected_ref, desired, expected_idx
839+
)
840+
841+
return ol_compare_exchange_impl

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/spv_fn_generator.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,59 @@ def get_or_insert_spv_atomic_exchange_fn(context, module, atomic_ref_ty):
119119
fn.calling_convention = CC_SPIR_FUNC
120120

121121
return fn
122+
123+
124+
def get_or_insert_spv_atomic_compare_exchange_fn(
125+
context, module, atomic_ref_ty
126+
):
127+
"""
128+
Gets or inserts a declaration for a __spirv_AtomicCompareExchange call into the
129+
specified LLVM IR module.
130+
"""
131+
atomic_ref_dtype = atomic_ref_ty.dtype
132+
133+
# Spirv spec requires arguments and return type to be of integer types.
134+
# That is why the type is changed from float to int
135+
# while maintaining the bit-width.
136+
# During function call, bitcasting is performed
137+
# to adhere to this convention.
138+
if atomic_ref_dtype == types.float32:
139+
atomic_ref_dtype = types.uint32
140+
elif atomic_ref_dtype == types.float64:
141+
atomic_ref_dtype = types.uint64
142+
143+
ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
144+
ptr_type.addrspace = atomic_ref_ty.address_space
145+
atomic_cmpexchg_fn_retty = context.get_value_type(atomic_ref_dtype)
146+
147+
atomic_cmpexchg_fn_arg_types = [
148+
ptr_type,
149+
llvmir.IntType(32),
150+
llvmir.IntType(32),
151+
llvmir.IntType(32),
152+
context.get_value_type(atomic_ref_dtype),
153+
context.get_value_type(atomic_ref_dtype),
154+
]
155+
156+
mangled_fn_name = ext_itanium_mangler.mangle_ext(
157+
"__spirv_AtomicCompareExchange",
158+
[
159+
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
160+
"__spv.Scope.Flag",
161+
"__spv.MemorySemanticsMask.Flag",
162+
"__spv.MemorySemanticsMask.Flag",
163+
atomic_ref_dtype,
164+
atomic_ref_dtype,
165+
],
166+
)
167+
168+
fn = cgutils.get_or_insert_function(
169+
module,
170+
llvmir.FunctionType(
171+
atomic_cmpexchg_fn_retty, atomic_cmpexchg_fn_arg_types
172+
),
173+
mangled_fn_name,
174+
)
175+
fn.calling_convention = CC_SPIR_FUNC
176+
177+
return fn

numba_dpex/kernel_api/atomic_ref.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,29 @@ def exchange(self, val):
200200
old = self._ref[self._index]
201201
self._ref[self._index] = val
202202
return old
203+
204+
def compare_exchange(self, expected, desired, expected_idx=0):
205+
"""Compares the value of the object referenced by the AtomicRef
206+
against the value of ``expected[expected_idx]``.
207+
If the values are equal, replaces the value of the
208+
referenced object with the value of ``desired``.
209+
Otherwise assigns the original value of the
210+
referenced object to ``expected[expected_idx]``.
211+
212+
Args:
213+
expected : Array containing the expected value of the
214+
object referenced by the AtomicRef.
215+
desired : Value that replaces the value of the object
216+
referenced by the AtomicRef.
217+
expected_idx: Offset in `expected` array where the expected
218+
value of the object referenced by the AtomicRef is present.
219+
220+
Returns: Returns ``True`` if the comparison operation and
221+
replacement operation were successful.
222+
223+
"""
224+
if self._ref[self._index] == expected[expected_idx]:
225+
self._ref[self._index] = desired
226+
return True
227+
expected[expected_idx] = self._ref[self._index]
228+
return False

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_load_store.py renamed to numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_load_store_cmp_exchg.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _kernel(a, b):
4545
and supported_dtype == dpnp.float64
4646
):
4747
pytest.xfail(
48-
"Atomic load, store, and exchange operations not working "
48+
"Atomic load and store operations not working "
4949
" for fp64 on OpenCL CPU"
5050
)
5151

@@ -82,8 +82,7 @@ def _kernel(a, b):
8282
and supported_dtype == dpnp.float64
8383
):
8484
pytest.xfail(
85-
"Atomic load, store, and exchange operations not working "
86-
" for fp64 on OpenCL CPU"
85+
"Atomic exchange operation not working " " for fp64 on OpenCL CPU"
8786
)
8887

8988
a_copy = dpnp.copy(a_orig)
@@ -100,6 +99,43 @@ def _kernel(a, b):
10099
assert b_copy[i] == a_orig[i]
101100

102101

102+
@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
103+
def test_compare_exchange_fns(supported_dtype):
104+
"""A test for compare exchange atomic functions."""
105+
106+
@dpex_exp.kernel
107+
def _kernel(b):
108+
b_ref = AtomicRef(b, index=1)
109+
b[0] = b_ref.compare_exchange(
110+
expected_ref=b, desired=b[3], expected_idx=2
111+
)
112+
113+
b = dpnp.arange(4, dtype=supported_dtype)
114+
115+
dev = b.sycl_device
116+
if (
117+
dev.backend == dpctl.backend_type.opencl
118+
and dev.device_type == dpctl.device_type.cpu
119+
and supported_dtype == dpnp.float64
120+
):
121+
pytest.xfail(
122+
"Atomic compare_exchange operation not working "
123+
" for fp64 on OpenCL CPU"
124+
)
125+
126+
dpex_exp.call_kernel(_kernel, dpex.Range(1), b)
127+
128+
# check for failure
129+
assert b[0] == 0
130+
assert b[2] == b[1]
131+
132+
dpex_exp.call_kernel(_kernel, dpex.Range(1), b)
133+
134+
# check for success
135+
assert b[0] == 1
136+
assert b[1] == b[3]
137+
138+
103139
def test_store_exchange_diff_types(store_exchange_fn):
104140
"""A negative test that verifies that a TypingError is raised if
105141
AtomicRef type and value are of different types.

0 commit comments

Comments
 (0)