Skip to content

Commit eb42463

Browse files
committed
More tests for diff
1 parent c8dbb99 commit eb42463

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

dpctl/tests/test_tensor_diff.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pytest
1818

1919
import dpctl.tensor as dpt
20+
from dpctl.tensor._type_utils import _to_device_supported_dtype
2021
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2122

2223
_all_dtypes = [
@@ -79,3 +80,45 @@ def test_diff_axis():
7980
expected_res[:, 1:, :], expected_res[:, :-1, :]
8081
)
8182
assert dpt.all(res == expected_res)
83+
84+
85+
def test_diff_prepend_append_type_promotion():
86+
get_queue_or_skip()
87+
88+
dts = [
89+
("i1", "u1", "i8"),
90+
("i1", "u8", "u1"),
91+
("u4", "i4", "f4"),
92+
("i8", "c8", "u8"),
93+
]
94+
95+
for _dts in dts:
96+
x = dpt.ones(10, dtype=_dts[1])
97+
prepend = dpt.full(1, 2, dtype=_dts[0])
98+
append = dpt.full(1, 3, dtype=_dts[2])
99+
100+
res = dpt.diff(x, prepend=prepend, append=append)
101+
assert res.dtype == _to_device_supported_dtype(
102+
dpt.result_type(prepend, x, append),
103+
x.sycl_queue.sycl_device,
104+
)
105+
106+
res = dpt.diff(x, prepend=prepend)
107+
assert res.dtype == _to_device_supported_dtype(
108+
dpt.result_type(prepend, x),
109+
x.sycl_queue.sycl_device,
110+
)
111+
112+
res = dpt.diff(x, append=append)
113+
assert res.dtype == _to_device_supported_dtype(
114+
dpt.result_type(x, append),
115+
x.sycl_queue.sycl_device,
116+
)
117+
118+
119+
def test_diff_0d():
120+
get_queue_or_skip()
121+
122+
x = dpt.ones(())
123+
with pytest.raises(ValueError):
124+
dpt.diff(x)

0 commit comments

Comments
 (0)