2
2
#
3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
+ import dpctl
5
6
import dpnp
6
7
import pytest
7
8
from numba .core .errors import TypingError
@@ -21,7 +22,8 @@ def store_exchange_fn(request):
21
22
return request .param
22
23
23
24
24
- def test_load_store_fn ():
25
+ @pytest .mark .parametrize ("supported_dtype" , list_of_supported_dtypes )
26
+ def test_load_store_fn (supported_dtype ):
25
27
"""A test for load/store atomic functions."""
26
28
27
29
@dpex_exp .kernel
@@ -33,8 +35,19 @@ def _kernel(a, b):
33
35
a_ref .store (val )
34
36
35
37
N = 10
36
- a = dpnp .zeros (2 * N , dtype = dpnp .float32 )
37
- b = dpnp .arange (N , dtype = dpnp .float32 )
38
+ a = dpnp .zeros (2 * N , dtype = supported_dtype )
39
+ b = dpnp .arange (N , dtype = supported_dtype )
40
+
41
+ dev = a .sycl_device
42
+ if (
43
+ dev .backend == dpctl .backend_type .opencl
44
+ and dev .device_type == dpctl .device_type .cpu
45
+ and supported_dtype == dpnp .float64
46
+ ):
47
+ pytest .xfail (
48
+ "Atomic load, store, and exchange operations not working "
49
+ " for fp64 on OpenCL CPU"
50
+ )
38
51
39
52
dpex_exp .call_kernel (_kernel , dpex .Range (b .size ), a , b )
40
53
# Verify that `b[i]` loaded and stored into a[i] by kernel
@@ -48,7 +61,8 @@ def _kernel(a, b):
48
61
assert a [i ] == a [i + b .size ]
49
62
50
63
51
- def test_exchange_fn ():
64
+ @pytest .mark .parametrize ("supported_dtype" , list_of_supported_dtypes )
65
+ def test_exchange_fn (supported_dtype ):
52
66
"""A test for exchange atomic function."""
53
67
54
68
@dpex_exp .kernel
@@ -58,8 +72,19 @@ def _kernel(a, b):
58
72
b [i ] = v .exchange (b [i ])
59
73
60
74
N = 10
61
- a_orig = dpnp .zeros (2 * N , dtype = dpnp .float32 )
62
- b_orig = dpnp .arange (N , dtype = dpnp .float32 )
75
+ a_orig = dpnp .zeros (2 * N , dtype = supported_dtype )
76
+ b_orig = dpnp .arange (N , dtype = supported_dtype )
77
+
78
+ dev = a_orig .sycl_device
79
+ if (
80
+ dev .backend == dpctl .backend_type .opencl
81
+ and dev .device_type == dpctl .device_type .cpu
82
+ and supported_dtype == dpnp .float64
83
+ ):
84
+ pytest .xfail (
85
+ "Atomic load, store, and exchange operations not working "
86
+ " for fp64 on OpenCL CPU"
87
+ )
63
88
64
89
a_copy = dpnp .copy (a_orig )
65
90
b_copy = dpnp .copy (b_orig )
0 commit comments