Skip to content

Commit 18da234

Browse files
author
Diptorup Deb
authored
Merge pull request #1319 from IntelPython/fix/fetch_pfi_returns
Capturing return value from fetch_* atomic functions
2 parents cf4b631 + afb4b0d commit 18da234

File tree

3 files changed

+41
-11
lines changed

3 files changed

+41
-11
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def _parse_enum_or_int_literal_(literal_int) -> int:
6767
def _intrinsic_helper(
6868
ty_context, ty_atomic_ref, ty_val, op_str # pylint: disable=unused-argument
6969
):
70-
sig = types.void(ty_atomic_ref, ty_val)
70+
sig = ty_atomic_ref.dtype(ty_atomic_ref, ty_val)
7171

7272
def gen(context, builder, sig, args):
7373
atomic_ref_ty = sig.args[0]
@@ -126,7 +126,7 @@ def gen(context, builder, sig, args):
126126
args[1],
127127
]
128128

129-
builder.call(func, fn_args)
129+
return builder.call(func, fn_args)
130130

131131
return sig, gen
132132

@@ -145,7 +145,7 @@ def gen(context, builder, sig, args):
145145
args_lst[1] = builder.fneg(args[1])
146146
args = tuple(args_lst)
147147

148-
gen_fn(context, builder, sig, args)
148+
return gen_fn(context, builder, sig, args)
149149

150150
return gen
151151

numba_dpex/kernel_api/atomic_ref.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def fetch_add(self, val):
6868
Returns: The original value of the object referenced by the AtomicRef.
6969
7070
"""
71-
old = self._ref[self._index]
71+
old = self._ref[self._index].copy()
7272
self._ref[self._index] += val
7373
return old
7474

@@ -84,7 +84,7 @@ def fetch_sub(self, val):
8484
Returns: The original value of the object referenced by the AtomicRef.
8585
8686
"""
87-
old = self._ref[self._index]
87+
old = self._ref[self._index].copy()
8888
self._ref[self._index] -= val
8989
return old
9090

@@ -100,7 +100,7 @@ def fetch_min(self, val):
100100
Returns: The original value of the object referenced by the AtomicRef.
101101
102102
"""
103-
old = self._ref[self._index]
103+
old = self._ref[self._index].copy()
104104
self._ref[self._index] = min(old, val)
105105
return old
106106

@@ -116,7 +116,7 @@ def fetch_max(self, val):
116116
Returns: The original value of the object referenced by the AtomicRef.
117117
118118
"""
119-
old = self._ref[self._index]
119+
old = self._ref[self._index].copy()
120120
self._ref[self._index] = max(old, val)
121121
return old
122122

@@ -132,7 +132,7 @@ def fetch_and(self, val):
132132
Returns: The original value of the object referenced by the AtomicRef.
133133
134134
"""
135-
old = self._ref[self._index]
135+
old = self._ref[self._index].copy()
136136
self._ref[self._index] &= val
137137
return old
138138

@@ -148,7 +148,7 @@ def fetch_or(self, val):
148148
Returns: The original value of the object referenced by the AtomicRef.
149149
150150
"""
151-
old = self._ref[self._index]
151+
old = self._ref[self._index].copy()
152152
self._ref[self._index] |= val
153153
return old
154154

@@ -164,7 +164,7 @@ def fetch_xor(self, val):
164164
Returns: The original value of the object referenced by the AtomicRef.
165165
166166
"""
167-
old = self._ref[self._index]
167+
old = self._ref[self._index].copy()
168168
self._ref[self._index] ^= val
169169
return old
170170

@@ -197,7 +197,7 @@ def exchange(self, val):
197197
Returns: The original value of the object referenced by the AtomicRef.
198198
199199
"""
200-
old = self._ref[self._index]
200+
old = self._ref[self._index].copy()
201201
self._ref[self._index] = val
202202
return old
203203

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_fetch_phi.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,36 @@ def _kernel(a, b, ref_index):
7272
assert b[ref_index] == b[ref_index + 1]
7373

7474

75+
def test_fetch_phi_retval(fetch_phi_fn):
76+
"""A test for all fetch_phi atomic functions."""
77+
78+
@dpex_exp.kernel
79+
def _kernel(a, b, c):
80+
i = dpex.get_global_id(0)
81+
v = AtomicRef(b, index=i)
82+
c[i] = getattr(v, fetch_phi_fn)(a[i])
83+
84+
N = 10
85+
a = dpnp.arange(N, dtype=dpnp.int32)
86+
b = dpnp.ones(N, dtype=dpnp.int32)
87+
c = dpnp.zeros(N, dtype=dpnp.int32)
88+
a_copy = dpnp.copy(a)
89+
b_copy = dpnp.copy(b)
90+
c_copy = dpnp.copy(c)
91+
92+
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, c)
93+
94+
# Verify if the value returned by fetch_phi kernel
95+
# stored into `c` is same as the value returned
96+
# by fetch_phi python stored into `c_copy`
97+
for i in range(a.size):
98+
v = AtomicRef(b_copy, index=i)
99+
c_copy[i] = getattr(v, fetch_phi_fn)(a_copy[i])
100+
101+
for i in range(a.size):
102+
assert c[i] == c_copy[i]
103+
104+
75105
def test_fetch_phi_diff_types(fetch_phi_fn):
76106
"""A negative test that verifies that a TypingError is raised if
77107
AtomicRef type and value to be added are of different types.

0 commit comments

Comments
 (0)