diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 7bea0289b..a58634b53 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -28,7 +28,7 @@ functions as functions_internal, ) from datafusion.common import NullTreatment, RexType, DataTypeMap -from typing import Any, Optional +from typing import Any, Optional, Type import pyarrow as pa # The following are imported from the internal representation. We may choose to @@ -372,8 +372,25 @@ def is_not_null(self) -> Expr: """Returns ``True`` if this expression is not null.""" return Expr(self.expr.is_not_null()) - def cast(self, to: pa.DataType[Any]) -> Expr: + _to_pyarrow_types = { + float: pa.float64(), + int: pa.int64(), + str: pa.string(), + bool: pa.bool_(), + } + + def cast( + self, to: pa.DataType[Any] | Type[float] | Type[int] | Type[str] | Type[bool] + ) -> Expr: """Cast to a new data type.""" + if not isinstance(to, pa.DataType): + try: + to = self._to_pyarrow_types[to] + except KeyError: + raise TypeError( + "Expected instance of pyarrow.DataType or builtins.type" + ) + return Expr(self.expr.cast(to)) def rex_type(self) -> RexType: diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index fe092c456..e7e6d79e1 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -44,8 +44,9 @@ def df(): datetime(2020, 7, 2), ] ), + pa.array([False, True, True]), ], - names=["a", "b", "c", "d"], + names=["a", "b", "c", "d", "e"], ) return ctx.create_dataframe([[batch]]) @@ -63,15 +64,14 @@ def test_named_struct(df): ) expected = """DataFrame() -+-------+---+---------+------------------------------+ -| a | b | c | d | -+-------+---+---------+------------------------------+ -| Hello | 4 | hello | {a: Hello, b: 4, c: hello } | -| World | 5 | world | {a: World, b: 5, c: world } | -| ! | 6 | ! | {a: !, b: 6, c: !} | -+-------+---+---------+------------------------------+ ++-------+---+---------+------------------------------+-------+ +| a | b | c | d | e | ++-------+---+---------+------------------------------+-------+ +| Hello | 4 | hello | {a: Hello, b: 4, c: hello } | false | +| World | 5 | world | {a: World, b: 5, c: world } | true | +| ! | 6 | ! | {a: !, b: 6, c: !} | true | ++-------+---+---------+------------------------------+-------+ """.strip() - assert str(df) == expected @@ -978,3 +978,22 @@ def test_binary_string_functions(df): assert pa.array(result.column(1)).cast(pa.string()) == pa.array( ["Hello", "World", "!"] ) + + +@pytest.mark.parametrize( + "python_datatype, name, expected", + [ + pytest.param(bool, "e", pa.bool_(), id="bool"), + pytest.param(int, "b", pa.int64(), id="int"), + pytest.param(float, "b", pa.float64(), id="float"), + pytest.param(str, "b", pa.string(), id="str"), + ], +) +def test_cast(df, python_datatype, name: str, expected): + df = df.select( + column(name).cast(python_datatype).alias("actual"), + column(name).cast(expected).alias("expected"), + ) + result = df.collect() + result = result[0] + assert result.column(0) == result.column(1)