Skip to content

Commit a19cb7d

Browse files
committed
Accept multidimentional atomic indexes
1 parent c2cb985 commit a19cb7d

File tree

2 files changed

+53
-11
lines changed

2 files changed

+53
-11
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,27 @@
4141
)
4242

4343

44+
def _normalize_indices(context, builder, indty, inds, aryty):
45+
"""
46+
Convert integer indices into tuple of intp
47+
"""
48+
if indty in types.integer_domain:
49+
indty = types.UniTuple(dtype=indty, count=1)
50+
indices = [inds]
51+
else:
52+
indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
53+
indices = [
54+
context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices)
55+
]
56+
57+
if aryty.ndim != len(indty):
58+
raise TypeError(
59+
f"indexing {aryty.ndim}-D array with {len(indty)}-D index"
60+
)
61+
62+
return indty, indices
63+
64+
4465
def _parse_enum_or_int_literal_(literal_int) -> int:
4566
"""Parse an instance of an enum class or numba.core.types.Literal to its
4667
actual int value.
@@ -208,23 +229,22 @@ def _intrinsic_atomic_ref_ctor(
208229
sig = ty_retty(ref, ty_index, ty_retty_ref)
209230

210231
def codegen(context, builder, sig, args):
211-
ref = args[0]
212-
index_pos = args[1]
232+
aryty, indty, _ = sig.args
233+
ary, inds, _ = args
213234

214-
dmm = context.data_model_manager
215-
data_attr_pos = dmm.lookup(sig.args[0]).get_field_position("data")
216-
data_attr = builder.extract_value(ref, data_attr_pos)
235+
indty, indices = _normalize_indices(
236+
context, builder, indty, inds, aryty
237+
)
217238

218-
with builder.goto_entry_block():
219-
ptr_to_data_attr = builder.alloca(data_attr.type)
220-
builder.store(data_attr, ptr_to_data_attr)
221-
ref_ptr_value = builder.gep(builder.load(ptr_to_data_attr), [index_pos])
239+
lary = context.make_array(aryty)(context, builder, ary)
240+
ref_ptr_value = cgutils.get_item_pointer(
241+
context, builder, aryty, lary, indices, wraparound=True
242+
)
222243

223244
atomic_ref_struct = cgutils.create_struct_proxy(ty_retty)(
224245
context, builder
225246
)
226-
ref_attr_pos = dmm.lookup(ty_retty).get_field_position("ref")
227-
atomic_ref_struct[ref_attr_pos] = ref_ptr_value
247+
atomic_ref_struct.ref = ref_ptr_value
228248
# pylint: disable=protected-access
229249
return atomic_ref_struct._getvalue()
230250

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_ref.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import dpnp
6+
import numpy as np
67
import pytest
78
from numba.core.errors import TypingError
89

@@ -25,6 +26,27 @@ def atomic_ref_kernel(item: Item, a, b):
2526
pytest.fail("Unexpected execution failure")
2627

2728

29+
def test_atomic_ref_3_dim_compilation():
30+
@dpex_exp.kernel
31+
def atomic_ref_kernel(item: Item, a, b):
32+
i = item.get_id(0)
33+
v = AtomicRef(b, index=(1, 1, 1), address_space=AddressSpace.GLOBAL)
34+
v.fetch_add(a[i])
35+
36+
a = dpnp.ones(8)
37+
b = dpnp.zeros((2, 2, 2))
38+
39+
want = np.zeros((2, 2, 2))
40+
want[1, 1, 1] = a.size
41+
42+
try:
43+
dpex_exp.call_kernel(atomic_ref_kernel, Range(a.size), a, b)
44+
except Exception:
45+
pytest.fail("Unexpected execution failure")
46+
47+
assert np.array_equal(b.asnumpy(), want)
48+
49+
2850
def test_atomic_ref_compilation_failure():
2951
"""A negative test that verifies that a TypingError is raised if we try to
3052
create an AtomicRef in the local address space from a global address space

0 commit comments

Comments
 (0)