Skip to content

Commit eda778d

Browse files
Add tests for dpctl.tensor.not_equal
1 parent 17ee7a4 commit eda778d

File tree

1 file changed

+186
-0
lines changed

1 file changed

+186
-0
lines changed
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import ctypes
18+
19+
import numpy as np
20+
import pytest
21+
22+
import dpctl
23+
import dpctl.tensor as dpt
24+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
25+
26+
from .utils import _all_dtypes, _compare_dtypes, _usm_types
27+
28+
29+
@pytest.mark.parametrize("op1_dtype", _all_dtypes)
30+
@pytest.mark.parametrize("op2_dtype", _all_dtypes)
31+
def test_not_equal_dtype_matrix(op1_dtype, op2_dtype):
32+
q = get_queue_or_skip()
33+
skip_if_dtype_not_supported(op1_dtype, q)
34+
skip_if_dtype_not_supported(op2_dtype, q)
35+
36+
sz = 127
37+
ar1 = dpt.ones(sz, dtype=op1_dtype)
38+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
39+
40+
r = dpt.not_equal(ar1, ar2)
41+
assert isinstance(r, dpt.usm_ndarray)
42+
expected_dtype = np.not_equal(
43+
np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype)
44+
).dtype
45+
assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q)
46+
assert r.shape == ar1.shape
47+
assert (dpt.asnumpy(r) == np.full(r.shape, False, dtype=r.dtype)).all()
48+
assert r.sycl_queue == ar1.sycl_queue
49+
50+
ar3 = dpt.ones(sz, dtype=op1_dtype)
51+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
52+
53+
r = dpt.not_equal(ar3[::-1], ar4[::2])
54+
assert isinstance(r, dpt.usm_ndarray)
55+
expected_dtype = np.not_equal(
56+
np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)
57+
).dtype
58+
assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q)
59+
assert r.shape == ar3.shape
60+
assert (dpt.asnumpy(r) == np.full(r.shape, False, dtype=r.dtype)).all()
61+
62+
63+
@pytest.mark.parametrize("op1_usm_type", _usm_types)
64+
@pytest.mark.parametrize("op2_usm_type", _usm_types)
65+
def test_not_equal_usm_type_matrix(op1_usm_type, op2_usm_type):
66+
get_queue_or_skip()
67+
68+
sz = 128
69+
ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type)
70+
ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type)
71+
72+
r = dpt.not_equal(ar1, ar2)
73+
assert isinstance(r, dpt.usm_ndarray)
74+
expected_usm_type = dpctl.utils.get_coerced_usm_type(
75+
(op1_usm_type, op2_usm_type)
76+
)
77+
assert r.usm_type == expected_usm_type
78+
79+
80+
def test_not_equal_order():
81+
get_queue_or_skip()
82+
83+
ar1 = dpt.ones((20, 20), dtype="i4", order="C")
84+
ar2 = dpt.ones((20, 20), dtype="i4", order="C")
85+
r1 = dpt.not_equal(ar1, ar2, order="C")
86+
assert r1.flags.c_contiguous
87+
r2 = dpt.not_equal(ar1, ar2, order="F")
88+
assert r2.flags.f_contiguous
89+
r3 = dpt.not_equal(ar1, ar2, order="A")
90+
assert r3.flags.c_contiguous
91+
r4 = dpt.not_equal(ar1, ar2, order="K")
92+
assert r4.flags.c_contiguous
93+
94+
ar1 = dpt.ones((20, 20), dtype="i4", order="F")
95+
ar2 = dpt.ones((20, 20), dtype="i4", order="F")
96+
r1 = dpt.not_equal(ar1, ar2, order="C")
97+
assert r1.flags.c_contiguous
98+
r2 = dpt.not_equal(ar1, ar2, order="F")
99+
assert r2.flags.f_contiguous
100+
r3 = dpt.not_equal(ar1, ar2, order="A")
101+
assert r3.flags.f_contiguous
102+
r4 = dpt.not_equal(ar1, ar2, order="K")
103+
assert r4.flags.f_contiguous
104+
105+
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
106+
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
107+
r4 = dpt.not_equal(ar1, ar2, order="K")
108+
assert r4.strides == (20, -1)
109+
110+
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
111+
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
112+
r4 = dpt.not_equal(ar1, ar2, order="K")
113+
assert r4.strides == (-1, 20)
114+
115+
116+
def test_not_equal_broadcasting():
117+
get_queue_or_skip()
118+
119+
m = dpt.ones((100, 5), dtype="i4")
120+
v = dpt.arange(5, dtype="i4")
121+
122+
r = dpt.not_equal(m, v)
123+
expected = np.full((100, 5), [True, False, True, True, True], dtype="?")
124+
125+
assert (dpt.asnumpy(r) == expected).all()
126+
127+
r2 = dpt.not_equal(v, m)
128+
assert (dpt.asnumpy(r2) == expected).all()
129+
130+
131+
@pytest.mark.parametrize("arr_dt", _all_dtypes)
132+
def test_not_equal_python_scalar(arr_dt):
133+
q = get_queue_or_skip()
134+
skip_if_dtype_not_supported(arr_dt, q)
135+
136+
X = dpt.zeros((10, 10), dtype=arr_dt, sycl_queue=q)
137+
py_zeros = (
138+
bool(0),
139+
int(0),
140+
float(0),
141+
complex(0),
142+
np.float32(0),
143+
ctypes.c_int(0),
144+
)
145+
for sc in py_zeros:
146+
R = dpt.not_equal(X, sc)
147+
assert isinstance(R, dpt.usm_ndarray)
148+
assert not dpt.all(R)
149+
R = dpt.not_equal(sc, X)
150+
assert isinstance(R, dpt.usm_ndarray)
151+
assert not dpt.all(R)
152+
153+
154+
class MockArray:
155+
def __init__(self, arr):
156+
self.data_ = arr
157+
158+
@property
159+
def __sycl_usm_array_interface__(self):
160+
return self.data_.__sycl_usm_array_interface__
161+
162+
163+
def test_not_equal_mock_array():
164+
get_queue_or_skip()
165+
a = dpt.arange(10)
166+
b = dpt.ones(10)
167+
c = MockArray(b)
168+
r = dpt.not_equal(a, c)
169+
assert isinstance(r, dpt.usm_ndarray)
170+
171+
172+
def test_not_equal_canary_mock_array():
173+
get_queue_or_skip()
174+
a = dpt.arange(10)
175+
176+
class Canary:
177+
def __init__(self):
178+
pass
179+
180+
@property
181+
def __sycl_usm_array_interface__(self):
182+
return None
183+
184+
c = Canary()
185+
with pytest.raises(ValueError):
186+
dpt.not_equal(a, c)

0 commit comments

Comments
 (0)