|
16 | 16 | from pytensor.xtensor.basic import rename
|
17 | 17 | from pytensor.xtensor.math import add, exp
|
18 | 18 | from pytensor.xtensor.type import xtensor
|
19 |
| -from tests.xtensor.util import xr_assert_allclose, xr_function |
| 19 | +from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function |
20 | 20 |
|
21 | 21 |
|
22 | 22 | def test_all_scalar_ops_are_wrapped():
|
@@ -132,3 +132,21 @@ def test_multiple_constant():
|
132 | 132 | res = fn(x_test)
|
133 | 133 | expected_res = np.exp(x_test * 2) + 2
|
134 | 134 | np.testing.assert_allclose(res, expected_res)
|
| 135 | + |
| 136 | + |
| 137 | +def test_cast(): |
| 138 | + x = xtensor("x", shape=(2, 3), dims=("a", "b"), dtype="float32") |
| 139 | + yf64 = x.astype("float64") |
| 140 | + yi16 = x.astype("int16") |
| 141 | + ybool = x.astype("bool") |
| 142 | + |
| 143 | + fn = xr_function([x], [yf64, yi16, ybool]) |
| 144 | + x_test = xr_arange_like(x) |
| 145 | + res_f64, res_i16, res_bool = fn(x_test) |
| 146 | + xr_assert_allclose(res_f64, x_test.astype("float64")) |
| 147 | + xr_assert_allclose(res_i16, x_test.astype("int16")) |
| 148 | + xr_assert_allclose(res_bool, x_test.astype("bool")) |
| 149 | + |
| 150 | + yc64 = x.astype("complex64") |
| 151 | + with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"): |
| 152 | + yc64.astype("float64") |
0 commit comments