|
3 | 3 | # SPDX-License-Identifier: Apache-2.0
|
4 | 4 |
|
5 | 5 | import dpctl
|
6 |
| -import numpy as np |
| 6 | +import dpnp as np |
7 | 7 | import pytest
|
8 | 8 |
|
9 | 9 | import numba_dpex as dpex
|
@@ -38,8 +38,11 @@ def fdtype(request):
|
38 | 38 |
|
39 | 39 | @pytest.fixture(params=list_of_i_dtypes + list_of_f_dtypes)
|
40 | 40 | def input_arrays(request):
|
41 |
| - a = np.array([0], request.param) |
42 |
| - return a, request.param |
| 41 | + def _inpute_arrays(filter_str): |
| 42 | + a = np.array([0], request.param, device=filter_str) |
| 43 | + return a, request.param |
| 44 | + |
| 45 | + return _inpute_arrays |
43 | 46 |
|
44 | 47 |
|
45 | 48 | list_of_op = [
|
@@ -72,11 +75,9 @@ def f(a):
|
72 | 75 | @pytest.mark.parametrize("filter_str", filter_strings)
|
73 | 76 | @skip_no_atomic_support
|
74 | 77 | def test_kernel_atomic_simple(filter_str, input_arrays, kernel_result_pair):
|
75 |
| - a, dtype = input_arrays |
| 78 | + a, dtype = input_arrays(filter_str) |
76 | 79 | kernel, expected = kernel_result_pair
|
77 |
| - device = dpctl.SyclDevice(filter_str) |
78 |
| - with dpctl.device_context(device): |
79 |
| - kernel[global_size, dpex.DEFAULT_LOCAL_SIZE](a) |
| 80 | + kernel[dpex.Range(global_size)](a) |
80 | 81 | assert a[0] == expected
|
81 | 82 |
|
82 | 83 |
|
@@ -114,15 +115,11 @@ def f(a):
|
114 | 115 | @pytest.mark.parametrize("filter_str", filter_strings)
|
115 | 116 | @skip_no_atomic_support
|
116 | 117 | def test_kernel_atomic_local(filter_str, input_arrays, return_list_of_op):
|
117 |
| - a, dtype = input_arrays |
| 118 | + a, dtype = input_arrays(filter_str) |
118 | 119 | op_type, expected = return_list_of_op
|
119 | 120 | f = get_func_local(op_type, dtype)
|
120 | 121 | kernel = dpex.kernel(f)
|
121 |
| - device = dpctl.SyclDevice(filter_str) |
122 |
| - with dpctl.device_context(device): |
123 |
| - gs = (N,) |
124 |
| - ls = (N,) |
125 |
| - kernel[gs, ls](a) |
| 122 | + kernel[dpex.Range(N), dpex.Range(N)](a) |
126 | 123 | assert a[0] == expected
|
127 | 124 |
|
128 | 125 |
|
@@ -161,10 +158,8 @@ def test_kernel_atomic_multi_dim(
|
161 | 158 | op_type, expected = return_list_of_op
|
162 | 159 | dim = return_list_of_dim
|
163 | 160 | kernel = get_kernel_multi_dim(op_type, len(dim))
|
164 |
| - a = np.zeros(dim, return_dtype) |
165 |
| - device = dpctl.SyclDevice(filter_str) |
166 |
| - with dpctl.device_context(device): |
167 |
| - kernel[global_size, dpex.DEFAULT_LOCAL_SIZE](a) |
| 161 | + a = np.zeros(dim, dtype=return_dtype, device=filter_str) |
| 162 | + kernel[dpex.Range(global_size)](a) |
168 | 163 | assert a[0] == expected
|
169 | 164 |
|
170 | 165 |
|
|
0 commit comments