|
17 | 17 | import pytest
|
18 | 18 |
|
19 | 19 | import dpctl.tensor as dpt
|
| 20 | +from dpctl.tensor._type_utils import _to_device_supported_dtype |
20 | 21 | from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
|
21 | 22 |
|
22 | 23 | _all_dtypes = [
|
@@ -79,3 +80,45 @@ def test_diff_axis():
|
79 | 80 | expected_res[:, 1:, :], expected_res[:, :-1, :]
|
80 | 81 | )
|
81 | 82 | assert dpt.all(res == expected_res)
|
| 83 | + |
| 84 | + |
| 85 | +def test_diff_prepend_append_type_promotion(): |
| 86 | + get_queue_or_skip() |
| 87 | + |
| 88 | + dts = [ |
| 89 | + ("i1", "u1", "i8"), |
| 90 | + ("i1", "u8", "u1"), |
| 91 | + ("u4", "i4", "f4"), |
| 92 | + ("i8", "c8", "u8"), |
| 93 | + ] |
| 94 | + |
| 95 | + for _dts in dts: |
| 96 | + x = dpt.ones(10, dtype=_dts[1]) |
| 97 | + prepend = dpt.full(1, 2, dtype=_dts[0]) |
| 98 | + append = dpt.full(1, 3, dtype=_dts[2]) |
| 99 | + |
| 100 | + res = dpt.diff(x, prepend=prepend, append=append) |
| 101 | + assert res.dtype == _to_device_supported_dtype( |
| 102 | + dpt.result_type(prepend, x, append), |
| 103 | + x.sycl_queue.sycl_device, |
| 104 | + ) |
| 105 | + |
| 106 | + res = dpt.diff(x, prepend=prepend) |
| 107 | + assert res.dtype == _to_device_supported_dtype( |
| 108 | + dpt.result_type(prepend, x), |
| 109 | + x.sycl_queue.sycl_device, |
| 110 | + ) |
| 111 | + |
| 112 | + res = dpt.diff(x, append=append) |
| 113 | + assert res.dtype == _to_device_supported_dtype( |
| 114 | + dpt.result_type(x, append), |
| 115 | + x.sycl_queue.sycl_device, |
| 116 | + ) |
| 117 | + |
| 118 | + |
| 119 | +def test_diff_0d(): |
| 120 | + get_queue_or_skip() |
| 121 | + |
| 122 | + x = dpt.ones(()) |
| 123 | + with pytest.raises(ValueError): |
| 124 | + dpt.diff(x) |
0 commit comments