Skip to content

Commit 70c099a

Browse files
feat: add cast to DataFrame (#916)
* feat: add with_columns * feat: add top level cast * chore: improve docstring --------- Co-authored-by: Tim Saucer <[email protected]>
1 parent 7cca028 commit 70c099a

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

python/datafusion/dataframe.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from __future__ import annotations
2323

24+
2425
from typing import Any, Iterable, List, Literal, TYPE_CHECKING
2526
from datafusion.record_batch import RecordBatchStream
2627
from typing_extensions import deprecated
@@ -267,6 +268,18 @@ def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
267268
exprs_raw = [sort_or_default(expr) for expr in exprs]
268269
return DataFrame(self.df.sort(*exprs_raw))
269270

271+
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
272+
"""Cast one or more columns to a different data type.
273+
274+
Args:
275+
mapping: Mapped with column as key and column dtype as value.
276+
277+
Returns:
278+
DataFrame after casting columns
279+
"""
280+
exprs = [Expr.column(col).cast(dtype) for col, dtype in mapping.items()]
281+
return self.with_columns(exprs)
282+
270283
def limit(self, count: int, offset: int = 0) -> DataFrame:
271284
"""Return a new :py:class:`DataFrame` with a limited number of rows.
272285

python/tests/test_dataframe.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,15 @@ def test_with_columns(df):
247247
assert result.column(6) == pa.array([5, 7, 9])
248248

249249

250+
def test_cast(df):
251+
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
252+
expected = pa.schema(
253+
[("a", pa.float16()), ("b", pa.list_(pa.uint32())), ("c", pa.int64())]
254+
)
255+
256+
assert df.schema() == expected
257+
258+
250259
def test_with_column_renamed(df):
251260
df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum")
252261

0 commit comments

Comments
 (0)