Skip to content

Commit cdb026f

Browse files
committed
Implement cast for XTensorVariables
1 parent be1330b commit cdb026f

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

pytensor/xtensor/math.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import sys
22

3+
import numpy as np
4+
35
import pytensor.scalar as ps
6+
from pytensor import config
47
from pytensor.scalar import ScalarOp
8+
from pytensor.scalar.basic import _cast_mapping
9+
from pytensor.xtensor.basic import as_xtensor
510
from pytensor.xtensor.vectorization import XElemwise
611

712

@@ -107,3 +112,25 @@ def _as_xelemwise(core_op: ScalarOp) -> XElemwise:
107112
true_divide = true_div = _as_xelemwise(ps.true_div)
108113
trunc = _as_xelemwise(ps.trunc)
109114
logical_xor = bitwise_xor = xor = _as_xelemwise(ps.xor)
115+
116+
_xelemwise_cast_op: dict[str, XElemwise] = {}
117+
118+
119+
def cast(x, dtype):
120+
if dtype == "floatX":
121+
dtype = config.floatX
122+
else:
123+
dtype = np.dtype(dtype).name
124+
125+
x = as_xtensor(x)
126+
if x.type.dtype == dtype:
127+
return x
128+
if x.type.dtype.startswith("complex") and not dtype.startswith("complex"):
129+
raise TypeError(
130+
"Casting from complex to real is ambiguous: consider"
131+
" real(), imag(), angle() or abs()"
132+
)
133+
134+
if dtype not in _xelemwise_cast_op:
135+
_xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
136+
return _xelemwise_cast_op[dtype](x)

pytensor/xtensor/type.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,9 @@ def copy(self, name: str | None = None):
401401
out.name = name # type: ignore
402402
return out
403403

404+
def astype(self, dtype):
405+
return px.math.cast(self, dtype)
406+
404407
def item(self):
405408
raise NotImplementedError("item not implemented for XTensorVariable")
406409

tests/xtensor/test_math.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pytensor.xtensor.basic import rename
1717
from pytensor.xtensor.math import add, exp
1818
from pytensor.xtensor.type import xtensor
19-
from tests.xtensor.util import xr_assert_allclose, xr_function
19+
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
2020

2121

2222
def test_all_scalar_ops_are_wrapped():
@@ -132,3 +132,21 @@ def test_multiple_constant():
132132
res = fn(x_test)
133133
expected_res = np.exp(x_test * 2) + 2
134134
np.testing.assert_allclose(res, expected_res)
135+
136+
137+
def test_cast():
138+
x = xtensor("x", shape=(2, 3), dims=("a", "b"), dtype="float32")
139+
yf64 = x.astype("float64")
140+
yi16 = x.astype("int16")
141+
ybool = x.astype("bool")
142+
143+
fn = xr_function([x], [yf64, yi16, ybool])
144+
x_test = xr_arange_like(x)
145+
res_f64, res_i16, res_bool = fn(x_test)
146+
xr_assert_allclose(res_f64, x_test.astype("float64"))
147+
xr_assert_allclose(res_i16, x_test.astype("int16"))
148+
xr_assert_allclose(res_bool, x_test.astype("bool"))
149+
150+
yc64 = x.astype("complex64")
151+
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
152+
yc64.astype("float64")

0 commit comments

Comments
 (0)