Skip to content

Commit 2df6fab

Browse files
authored
Merge pull request #1367 from IntelPython/feature/accept_multidimentional_atomic_indexes
Accept multidimentional atomic indexes
2 parents c2cb985 + 13e6a9d commit 2df6fab

File tree

5 files changed

+57
-21
lines changed

5 files changed

+57
-21
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,27 @@
4141
)
4242

4343

44+
def _normalize_indices(context, builder, indty, inds, aryty):
45+
"""
46+
Convert integer indices into tuple of intp
47+
"""
48+
if indty in types.integer_domain:
49+
indty = types.UniTuple(dtype=indty, count=1)
50+
indices = [inds]
51+
else:
52+
indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
53+
indices = [
54+
context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices)
55+
]
56+
57+
if aryty.ndim != len(indty):
58+
raise TypeError(
59+
f"indexing {aryty.ndim}-D array with {len(indty)}-D index"
60+
)
61+
62+
return indty, indices
63+
64+
4465
def _parse_enum_or_int_literal_(literal_int) -> int:
4566
"""Parse an instance of an enum class or numba.core.types.Literal to its
4667
actual int value.
@@ -208,23 +229,22 @@ def _intrinsic_atomic_ref_ctor(
208229
sig = ty_retty(ref, ty_index, ty_retty_ref)
209230

210231
def codegen(context, builder, sig, args):
211-
ref = args[0]
212-
index_pos = args[1]
232+
aryty, indty, _ = sig.args
233+
ary, inds, _ = args
213234

214-
dmm = context.data_model_manager
215-
data_attr_pos = dmm.lookup(sig.args[0]).get_field_position("data")
216-
data_attr = builder.extract_value(ref, data_attr_pos)
235+
indty, indices = _normalize_indices(
236+
context, builder, indty, inds, aryty
237+
)
217238

218-
with builder.goto_entry_block():
219-
ptr_to_data_attr = builder.alloca(data_attr.type)
220-
builder.store(data_attr, ptr_to_data_attr)
221-
ref_ptr_value = builder.gep(builder.load(ptr_to_data_attr), [index_pos])
239+
lary = context.make_array(aryty)(context, builder, ary)
240+
ref_ptr_value = cgutils.get_item_pointer(
241+
context, builder, aryty, lary, indices, wraparound=True
242+
)
222243

223244
atomic_ref_struct = cgutils.create_struct_proxy(ty_retty)(
224245
context, builder
225246
)
226-
ref_attr_pos = dmm.lookup(ty_retty).get_field_position("ref")
227-
atomic_ref_struct[ref_attr_pos] = ref_ptr_value
247+
atomic_ref_struct.ref = ref_ptr_value
228248
# pylint: disable=protected-access
229249
return atomic_ref_struct._getvalue()
230250

@@ -564,7 +584,7 @@ def _check_if_supported_ref(ref):
564584
)
565585
def ol_atomic_ref(
566586
ref,
567-
index=0,
587+
index,
568588
memory_order=MemoryOrder.RELAXED,
569589
memory_scope=MemoryScope.DEVICE,
570590
address_space=AddressSpace.GLOBAL,
@@ -635,7 +655,7 @@ def ol_atomic_ref(
635655

636656
def ol_atomic_ref_ctor_impl(
637657
ref,
638-
index=0,
658+
index,
639659
memory_order=MemoryOrder.RELAXED, # pylint: disable=unused-argument
640660
memory_scope=MemoryScope.DEVICE, # pylint: disable=unused-argument
641661
address_space=AddressSpace.GLOBAL, # pylint: disable=unused-argument

numba_dpex/kernel_api/atomic_ref.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class AtomicRef:
1818
def __init__( # pylint: disable=too-many-arguments
1919
self,
2020
ref,
21-
index=0,
21+
index,
2222
memory_order=MemoryOrder.RELAXED,
2323
memory_scope=MemoryScope.DEVICE,
2424
address_space=AddressSpace.GLOBAL,

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_fence.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,16 @@
99
MemoryScope,
1010
atomic_fence,
1111
)
12-
from numba_dpex.tests._helper import skip_windows
1312

1413

15-
# TODO: https://github.com/IntelPython/numba-dpex/issues/1308
16-
@skip_windows
1714
def test_atomic_fence():
1815
"""A test for atomic_fence function."""
1916

2017
@dpex_exp.kernel
2118
def _kernel(item: Item, a, b):
2219
i = item.get_id(0)
2320

24-
bref = AtomicRef(b)
21+
bref = AtomicRef(b, index=0)
2522

2623
if i == 1:
2724
a[i] += 1

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_ref.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import dpnp
6+
import numpy as np
67
import pytest
78
from numba.core.errors import TypingError
89

@@ -25,6 +26,27 @@ def atomic_ref_kernel(item: Item, a, b):
2526
pytest.fail("Unexpected execution failure")
2627

2728

29+
def test_atomic_ref_3_dim_compilation():
30+
@dpex_exp.kernel
31+
def atomic_ref_kernel(item: Item, a, b):
32+
i = item.get_id(0)
33+
v = AtomicRef(b, index=(1, 1, 1), address_space=AddressSpace.GLOBAL)
34+
v.fetch_add(a[i])
35+
36+
a = dpnp.ones(8)
37+
b = dpnp.zeros((2, 2, 2))
38+
39+
want = np.zeros((2, 2, 2))
40+
want[1, 1, 1] = a.size
41+
42+
try:
43+
dpex_exp.call_kernel(atomic_ref_kernel, Range(a.size), a, b)
44+
except Exception:
45+
pytest.fail("Unexpected execution failure")
46+
47+
assert np.array_equal(b.asnumpy(), want)
48+
49+
2850
def test_atomic_ref_compilation_failure():
2951
"""A negative test that verifies that a TypingError is raised if we try to
3052
create an AtomicRef in the local address space from a global address space

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_barriers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
import numba_dpex as dpex
44
import numba_dpex.experimental as dpex_exp
55
from numba_dpex.kernel_api import MemoryScope, NdItem, group_barrier
6-
from numba_dpex.tests._helper import skip_windows
76

87

9-
# TODO: https://github.com/IntelPython/numba-dpex/issues/1308
10-
@skip_windows
118
def test_group_barrier():
129
"""A test for group_barrier function."""
1310

0 commit comments

Comments
 (0)