Skip to content

Commit 1f19e4e

Browse files
author
Diptorup Deb
authored
Merge pull request #1015 from adarshyoga/fix/atomic_tests
Making test_atomic_op testcase CFD compliant
2 parents 0915170 + aba4385 commit 1f19e4e

File tree

1 file changed

+12
-17
lines changed

1 file changed

+12
-17
lines changed

numba_dpex/tests/kernel_tests/test_atomic_op.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import dpctl
6-
import numpy as np
6+
import dpnp as np
77
import pytest
88

99
import numba_dpex as dpex
@@ -38,8 +38,11 @@ def fdtype(request):
3838

3939
@pytest.fixture(params=list_of_i_dtypes + list_of_f_dtypes)
4040
def input_arrays(request):
41-
a = np.array([0], request.param)
42-
return a, request.param
41+
def _inpute_arrays(filter_str):
42+
a = np.array([0], request.param, device=filter_str)
43+
return a, request.param
44+
45+
return _inpute_arrays
4346

4447

4548
list_of_op = [
@@ -72,11 +75,9 @@ def f(a):
7275
@pytest.mark.parametrize("filter_str", filter_strings)
7376
@skip_no_atomic_support
7477
def test_kernel_atomic_simple(filter_str, input_arrays, kernel_result_pair):
75-
a, dtype = input_arrays
78+
a, dtype = input_arrays(filter_str)
7679
kernel, expected = kernel_result_pair
77-
device = dpctl.SyclDevice(filter_str)
78-
with dpctl.device_context(device):
79-
kernel[global_size, dpex.DEFAULT_LOCAL_SIZE](a)
80+
kernel[dpex.Range(global_size)](a)
8081
assert a[0] == expected
8182

8283

@@ -114,15 +115,11 @@ def f(a):
114115
@pytest.mark.parametrize("filter_str", filter_strings)
115116
@skip_no_atomic_support
116117
def test_kernel_atomic_local(filter_str, input_arrays, return_list_of_op):
117-
a, dtype = input_arrays
118+
a, dtype = input_arrays(filter_str)
118119
op_type, expected = return_list_of_op
119120
f = get_func_local(op_type, dtype)
120121
kernel = dpex.kernel(f)
121-
device = dpctl.SyclDevice(filter_str)
122-
with dpctl.device_context(device):
123-
gs = (N,)
124-
ls = (N,)
125-
kernel[gs, ls](a)
122+
kernel[dpex.Range(N), dpex.Range(N)](a)
126123
assert a[0] == expected
127124

128125

@@ -161,10 +158,8 @@ def test_kernel_atomic_multi_dim(
161158
op_type, expected = return_list_of_op
162159
dim = return_list_of_dim
163160
kernel = get_kernel_multi_dim(op_type, len(dim))
164-
a = np.zeros(dim, return_dtype)
165-
device = dpctl.SyclDevice(filter_str)
166-
with dpctl.device_context(device):
167-
kernel[global_size, dpex.DEFAULT_LOCAL_SIZE](a)
161+
a = np.zeros(dim, dtype=return_dtype, device=filter_str)
162+
kernel[dpex.Range(global_size)](a)
168163
assert a[0] == expected
169164

170165

0 commit comments

Comments
 (0)