|
23 | 23 |
|
24 | 24 | from datafusion import SessionContext, column |
25 | 25 | from datafusion import functions as f |
26 | | -from datafusion import literal |
| 26 | +from datafusion import literal, string_literal |
27 | 27 |
|
28 | 28 | np.seterr(invalid="ignore") |
29 | 29 |
|
@@ -907,6 +907,22 @@ def test_temporal_functions(df): |
907 | 907 | assert result.column(10) == pa.array([31, 26, 2], type=pa.float64()) |
908 | 908 |
|
909 | 909 |
|
| 910 | +def test_arrow_cast(df): |
| 911 | + df = df.select( |
| 912 | + # we use `string_literal` to return utf8 instead of `literal` which returns |
| 913 | + # utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view |
| 914 | + # https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179 |
| 915 | + f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"), |
| 916 | + f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"), |
| 917 | + ) |
| 918 | + result = df.collect() |
| 919 | + assert len(result) == 1 |
| 920 | + result = result[0] |
| 921 | + |
| 922 | + assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64()) |
| 923 | + assert result.column(1) == pa.array([4, 5, 6], type=pa.int32()) |
| 924 | + |
| 925 | + |
910 | 926 | def test_case(df): |
911 | 927 | df = df.select( |
912 | 928 | f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)), |
|
0 commit comments