Skip to content

Commit 315b6e7

Browse files
author
Diptorup Deb
authored
Merge pull request #1261 from IntelPython/experimental/more_fetch_phi_fns
Adds all fetch_* SPIR-V overload to experimental
2 parents 089402a + 3edf70a commit 315b6e7

File tree

3 files changed

+319
-10
lines changed

3 files changed

+319
-10
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ def gen(context, builder, sig, args):
6969
"--spirv-ext=+SPV_EXT_shader_atomic_float_add"
7070
]
7171

72+
context.extra_compile_options[LLVM_SPIRV_ARGS] = [
73+
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max"
74+
]
75+
7276
ptr_type = retty.as_pointer()
7377
ptr_type.addrspace = atomic_ref_ty.address_space
7478

@@ -118,6 +122,59 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
118122
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add")
119123

120124

125+
def _atomic_sub_float_wrapper(gen_fn):
    """Adapt a ``fetch_add`` code generator into a float ``fetch_sub`` one.

    Args:
        gen_fn: A code generation callable invoked as
            ``gen_fn(context, builder, sig, args)``.

    Returns:
        A callable with the same call signature that negates ``args[1]``
        with ``builder.fneg`` and then delegates to ``gen_fn``, returning
        whatever ``gen_fn`` returns.
    """

    def gen(context, builder, sig, args):
        # ``args`` is an immutable tuple, so build a new tuple in which the
        # operand at index 1 is replaced by its negation.
        negated_args = (args[0], builder.fneg(args[1]), *args[2:])

        # Hand the wrapped generator's result back to the caller instead of
        # dropping it; otherwise this wrapper would always yield ``None``.
        return gen_fn(context, builder, sig, negated_args)

    return gen
137+
138+
139+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic for ``fetch_sub`` on an atomic reference.

    For non-float dtypes the work is delegated to ``_intrinsic_helper``
    with the ``"fetch_sub"`` operation. For ``float32``/``float64`` the
    ``"fetch_add"`` helper is used and its code generator is wrapped so
    that the operand is negated first.

    Returns:
        The ``(signature, codegen)`` pair for the intrinsic.
    """
    is_float_ref = ty_atomic_ref.dtype in (types.float32, types.float64)
    if not is_float_ref:
        return _intrinsic_helper(
            ty_context, ty_atomic_ref, ty_val, "fetch_sub"
        )

    # dpcpp does not support ``__spirv_AtomicFSubEXT``, so a float fetch_sub
    # is emitted as a fetch_add of the negated value, i.e.
    # ``A.fetch_sub(val)`` becomes ``A.fetch_add(-val)``.
    signature, add_codegen = _intrinsic_helper(
        ty_context, ty_atomic_ref, ty_val, "fetch_add"
    )
    return signature, _atomic_sub_float_wrapper(add_codegen)
151+
152+
153+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_min(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic for ``fetch_min``; delegates to ``_intrinsic_helper``
    with the ``"fetch_min"`` operation name and returns its result.
    """
    sig_and_codegen = _intrinsic_helper(
        ty_context, ty_atomic_ref, ty_val, "fetch_min"
    )
    return sig_and_codegen
156+
157+
158+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_max(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic for ``fetch_max``; delegates to ``_intrinsic_helper``
    with the ``"fetch_max"`` operation name and returns its result.
    """
    sig_and_codegen = _intrinsic_helper(
        ty_context, ty_atomic_ref, ty_val, "fetch_max"
    )
    return sig_and_codegen
161+
162+
163+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_and(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic for ``fetch_and``; delegates to ``_intrinsic_helper``
    with the ``"fetch_and"`` operation name and returns its result.
    """
    sig_and_codegen = _intrinsic_helper(
        ty_context, ty_atomic_ref, ty_val, "fetch_and"
    )
    return sig_and_codegen
166+
167+
168+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_or(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic for ``fetch_or``; delegates to ``_intrinsic_helper``
    with the ``"fetch_or"`` operation name and returns its result.
    """
    sig_and_codegen = _intrinsic_helper(
        ty_context, ty_atomic_ref, ty_val, "fetch_or"
    )
    return sig_and_codegen
171+
172+
173+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_xor(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic for ``fetch_xor``; delegates to ``_intrinsic_helper``
    with the ``"fetch_xor"`` operation name and returns its result.
    """
    sig_and_codegen = _intrinsic_helper(
        ty_context, ty_atomic_ref, ty_val, "fetch_xor"
    )
    return sig_and_codegen
176+
177+
121178
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
122179
def _intrinsic_atomic_ref_ctor(
123180
ty_context, ref, ty_index, ty_retty_ref # pylint: disable=unused-argument
@@ -294,3 +351,168 @@ def ol_fetch_add_impl(atomic_ref, val):
294351
return _intrinsic_fetch_add(atomic_ref, val)
295352

296353
return ol_fetch_add_impl
354+
355+
356+
@overload_method(AtomicRefType, "fetch_sub", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_sub(atomic_ref, val):
    """SPIR-V overload of
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_sub`.

    The returned implementation forwards to ``_intrinsic_fetch_sub``, which
    is intended to emit the same LLVM IR as dpcpp's `atomic_ref::fetch_sub`.

    Raises:
        TypingError: If the type of ``val`` differs from the dtype of the
            AtomicRef type.
    """
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to sub: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def fetch_sub_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_sub(atomic_ref, val)

    return fetch_sub_impl
379+
380+
381+
@overload_method(AtomicRefType, "fetch_min", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_min(atomic_ref, val):
    """SPIR-V overload of
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_min`.

    The returned implementation forwards to ``_intrinsic_fetch_min``, which
    is intended to emit the same LLVM IR as dpcpp's `atomic_ref::fetch_min`.

    Raises:
        TypingError: If the type of ``val`` differs from the dtype of the
            AtomicRef type.
    """
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to find min: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def fetch_min_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_min(atomic_ref, val)

    return fetch_min_impl
404+
405+
406+
@overload_method(AtomicRefType, "fetch_max", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_max(atomic_ref, val):
    """SPIR-V overload of
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_max`.

    The returned implementation forwards to ``_intrinsic_fetch_max``, which
    is intended to emit the same LLVM IR as dpcpp's `atomic_ref::fetch_max`.

    Raises:
        TypingError: If the type of ``val`` differs from the dtype of the
            AtomicRef type.
    """
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to find max: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def fetch_max_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_max(atomic_ref, val)

    return fetch_max_impl
429+
430+
431+
@overload_method(AtomicRefType, "fetch_and", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_and(atomic_ref, val):
    """SPIR-V overload of
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_and`.

    The returned implementation forwards to ``_intrinsic_fetch_and``, which
    is intended to emit the same LLVM IR as dpcpp's `atomic_ref::fetch_and`.

    Raises:
        TypingError: If the type of ``val`` differs from the dtype of the
            AtomicRef type, or if that dtype is neither ``int32`` nor
            ``int64``.
    """
    # The dtype-mismatch check runs first so that its message takes
    # precedence over the integer-only restriction.
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to and: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    supported_dtypes = (types.int32, types.int64)
    if atomic_ref.dtype not in supported_dtypes:
        raise errors.TypingError(
            "fetch_and operation only supported on int32 and int64 dtypes."
        )

    def fetch_and_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_and(atomic_ref, val)

    return fetch_and_impl
459+
460+
461+
@overload_method(AtomicRefType, "fetch_or", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_or(atomic_ref, val):
    """SPIR-V overload of
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_or`.

    The returned implementation forwards to ``_intrinsic_fetch_or``, which
    is intended to emit the same LLVM IR as dpcpp's `atomic_ref::fetch_or`.

    Raises:
        TypingError: If the type of ``val`` differs from the dtype of the
            AtomicRef type, or if that dtype is neither ``int32`` nor
            ``int64``.
    """
    # The dtype-mismatch check runs first so that its message takes
    # precedence over the integer-only restriction.
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to or: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    supported_dtypes = (types.int32, types.int64)
    if atomic_ref.dtype not in supported_dtypes:
        raise errors.TypingError(
            "fetch_or operation only supported on int32 and int64 dtypes."
        )

    def fetch_or_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_or(atomic_ref, val)

    return fetch_or_impl
489+
490+
491+
@overload_method(AtomicRefType, "fetch_xor", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_xor(atomic_ref, val):
    """SPIR-V overload of
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_xor`.

    The returned implementation forwards to ``_intrinsic_fetch_xor``, which
    is intended to emit the same LLVM IR as dpcpp's `atomic_ref::fetch_xor`.

    Raises:
        TypingError: If the type of ``val`` differs from the dtype of the
            AtomicRef type, or if that dtype is neither ``int32`` nor
            ``int64``.
    """
    # The dtype-mismatch check runs first so that its message takes
    # precedence over the integer-only restriction.
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to xor: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    supported_dtypes = (types.int32, types.int64)
    if atomic_ref.dtype not in supported_dtypes:
        raise errors.TypingError(
            "fetch_xor operation only supported on int32 and int64 dtypes."
        )

    def fetch_xor_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_xor(atomic_ref, val)

    return fetch_xor_impl

numba_dpex/spirv_generator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,10 @@ def finalize(self):
159159
# TODO: find better approach to set SPIRV compiler arguments. Workaround
160160
# against caching intrinsic that sets this argument.
161161
# https://github.com/IntelPython/numba-dpex/issues/1262
162-
llvm_spirv_args = ["--spirv-ext=+SPV_EXT_shader_atomic_float_add"]
162+
llvm_spirv_args = [
163+
"--spirv-ext=+SPV_EXT_shader_atomic_float_add",
164+
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max",
165+
]
163166
for key in list(self.context.extra_compile_options.keys()):
164167
if key == LLVM_SPIRV_ARGS:
165168
llvm_spirv_args = self.context.extra_compile_options[key]

numba_dpex/tests/experimental/kernel_iface/spv_overloads/test_atomic_fetch_phi.py

Lines changed: 93 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import dpnp
66
import pytest
7+
from numba.core.errors import TypingError
78

89
import numba_dpex as dpex
910
import numba_dpex.experimental as dpex_exp
@@ -14,30 +15,80 @@
1415
no_bool=True, no_float16=True, no_none=True, no_complex=True
1516
)
1617

18+
# Names of the AtomicRef fetch_* methods exercised by the parametrized tests.
list_of_fetch_phi_funcs = [
    "fetch_add",
    "fetch_sub",
    "fetch_min",
    "fetch_max",
    "fetch_and",
    "fetch_or",
    "fetch_xor",
]


@pytest.fixture(params=list_of_fetch_phi_funcs)
def fetch_phi_fn(request):
    """Yield each name in ``list_of_fetch_phi_funcs``, one per test run."""
    method_name = request.param
    return method_name
32+
1733

1834
@pytest.fixture(params=list_of_supported_dtypes)
1935
def input_arrays(request):
2036
# The size of input and out arrays to be used
2137
N = 10
22-
a = dpnp.ones(N, dtype=request.param)
23-
b = dpnp.zeros(N, dtype=request.param)
38+
a = dpnp.arange(N, dtype=request.param)
39+
b = dpnp.ones(N, dtype=request.param)
2440
return a, b
2541

2642

2743
@pytest.mark.parametrize("ref_index", [0, 5])
28-
def test_fetch_add(input_arrays, ref_index):
44+
def test_fetch_phi_fn(input_arrays, ref_index, fetch_phi_fn):
45+
"""A test for all fetch_phi atomic functions."""
46+
2947
@dpex_exp.kernel
30-
def atomic_ref_kernel(a, b, ref_index):
48+
def _kernel(a, b, ref_index):
3149
i = dpex.get_global_id(0)
3250
v = AtomicRef(b, index=ref_index)
33-
v.fetch_add(a[i])
51+
getattr(v, fetch_phi_fn)(a[i])
3452

3553
a, b = input_arrays
3654

37-
dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b, ref_index)
55+
if (
56+
fetch_phi_fn in ["fetch_and", "fetch_or", "fetch_xor"]
57+
and issubclass(a.dtype.type, dpnp.floating)
58+
and issubclass(b.dtype.type, dpnp.floating)
59+
):
60+
# fetch_and, fetch_or, fetch_xor accept only int arguments.
61+
# test for TypingError when float arguments are passed.
62+
with pytest.raises(TypingError):
63+
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, ref_index)
64+
else:
65+
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, ref_index)
66+
# Verify that `a` accumulated at b[ref_index] by kernel
67+
# matches the `a` accumulated at b[ref_index+1] using Python
68+
for i in range(a.size):
69+
v = AtomicRef(b, index=ref_index + 1)
70+
getattr(v, fetch_phi_fn)(a[i])
71+
72+
assert b[ref_index] == b[ref_index + 1]
73+
74+
75+
def test_fetch_phi_diff_types(fetch_phi_fn):
76+
"""A negative test that verifies that a TypingError is raised if
77+
AtomicRef type and value to be added are of different types.
78+
"""
79+
80+
@dpex_exp.kernel
81+
def _kernel(a, b):
82+
i = dpex.get_global_id(0)
83+
v = AtomicRef(b, index=0)
84+
getattr(v, fetch_phi_fn)(a[i])
85+
86+
N = 10
87+
a = dpnp.ones(N, dtype=dpnp.float32)
88+
b = dpnp.zeros(N, dtype=dpnp.int32)
3889

39-
# Verify that `a` was accumulated at b[ref_index]
40-
assert b[ref_index] == 10
90+
with pytest.raises(TypingError):
91+
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b)
4192

4293

4394
@dpex_exp.kernel
@@ -54,7 +105,7 @@ def atomic_ref_1(a):
54105
v.fetch_add(a[i + 2])
55106

56107

57-
def test_spirv_compiler_flags():
108+
def test_spirv_compiler_flags_add():
58109
"""Check if float atomic flag is being populated from intrinsic for the
59110
second call.
60111
@@ -68,3 +119,36 @@ def test_spirv_compiler_flags():
68119

69120
assert a[0] == N - 1
70121
assert a[1] == N - 1
122+
123+
124+
@dpex_exp.kernel
def atomic_max_0(a):
    # Every work item except item 0 applies fetch_max with its own element
    # of ``a`` to an atomic reference on a[0].
    gid = dpex.get_global_id(0)
    ref = AtomicRef(a, index=0)
    if gid != 0:
        ref.fetch_max(a[gid])
130+
131+
132+
@dpex_exp.kernel
def atomic_max_1(a):
    # Same body as ``atomic_max_0``, kept as a separate kernel so the test
    # can launch two distinct kernels that both use fetch_max.
    gid = dpex.get_global_id(0)
    ref = AtomicRef(a, index=0)
    if gid != 0:
        ref.fetch_max(a[gid])
138+
139+
140+
def test_spirv_compiler_flags_max():
    """Check that the float atomic flag is still populated from the
    intrinsic when a second kernel using ``fetch_max`` is launched.

    https://github.com/IntelPython/numba-dpex/issues/1262
    """
    size = 10
    expected_max = size - 1
    first = dpnp.arange(size, dtype=dpnp.float32)
    second = dpnp.arange(size, dtype=dpnp.float32)

    dpex_exp.call_kernel(atomic_max_0, dpex.Range(size), first)
    dpex_exp.call_kernel(atomic_max_1, dpex.Range(size), second)

    # Both kernels must have reduced the maximum element into index 0.
    assert first[0] == expected_max
    assert second[0] == expected_max

0 commit comments

Comments
 (0)