Skip to content

Commit d4b575f

Browse files
committed
Adds test_tensor_diff with tests for basic diff and axis keyword
1 parent 6664b18 commit d4b575f

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

dpctl/tests/test_tensor_diff.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2024 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl.tensor as dpt
20+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
21+
22+
_all_dtypes = [
23+
"?",
24+
"i1",
25+
"u1",
26+
"i2",
27+
"u2",
28+
"i4",
29+
"u4",
30+
"i8",
31+
"u8",
32+
"f2",
33+
"f4",
34+
"f8",
35+
"c8",
36+
"c16",
37+
]
38+
39+
40+
@pytest.mark.parametrize("dt", _all_dtypes)
41+
def test_diff_basic(dt):
42+
q = get_queue_or_skip()
43+
skip_if_dtype_not_supported(dt, q)
44+
45+
x = dpt.asarray([9, 12, 7, 17, 10, 18, 15, 9, 8, 8], dtype=dt, sycl_queue=q)
46+
res = dpt.diff(x)
47+
op = dpt.not_equal if x.dtype is dpt.bool else dpt.subtract
48+
expected_res = op(x[1:], x[:-1])
49+
if dpt.dtype(dt).kind in "fc":
50+
assert dpt.allclose(res, expected_res)
51+
else:
52+
assert dpt.all(res == expected_res)
53+
54+
res = dpt.diff(x, n=5)
55+
expected_res = x
56+
for _ in range(5):
57+
expected_res = op(expected_res[1:], expected_res[:-1])
58+
if dpt.dtype(dt).kind in "fc":
59+
assert dpt.allclose(res, expected_res)
60+
else:
61+
assert dpt.all(res == expected_res)
62+
63+
64+
def test_diff_axis():
65+
get_queue_or_skip()
66+
67+
x = dpt.tile(
68+
dpt.asarray([9, 12, 7, 17, 10, 18, 15, 9, 8, 8], dtype="i4"), (3, 4, 1)
69+
)
70+
x[:, ::2, :] = 0
71+
res = dpt.diff(x, n=1, axis=1)
72+
expected_res = dpt.subtract(x[:, 1:, :], x[:, :-1, :])
73+
assert dpt.all(res == expected_res)
74+
75+
res = dpt.diff(x, n=3, axis=1)
76+
expected_res = x
77+
for _ in range(3):
78+
expected_res = dpt.subtract(
79+
expected_res[:, 1:, :], expected_res[:, :-1, :]
80+
)
81+
assert dpt.all(res == expected_res)

0 commit comments

Comments
 (0)