Skip to content

Commit bded1c4

Browse files
committed
feat: add top level cast
1 parent 35656f1 commit bded1c4

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

python/datafusion/dataframe.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,19 @@ def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
245245
exprs_raw = [sort_or_default(expr) for expr in exprs]
246246
return DataFrame(self.df.sort(*exprs_raw))
247247

248+
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
249+
"""Cast all or a subset of columns to new dtype.
250+
251+
Args:
252+
mapping (dict[str, pa.DataType[Any]]): Mapped with column as key and column
253+
dtype as value.
254+
255+
Returns:
256+
DataFrame after casting columns
257+
"""
258+
exprs = [Expr.column(col).cast(dtype) for col, dtype in mapping.items()]
259+
return self.with_columns(exprs)
260+
248261
def limit(self, count: int, offset: int = 0) -> DataFrame:
249262
"""Return a new :py:class:`DataFrame` with a limited number of rows.
250263

python/tests/test_dataframe.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,15 @@ def test_with_columns(df):
236236
assert result.column(6) == pa.array([5, 7, 9])
237237

238238

239+
def test_cast(df):
240+
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
241+
expected = pa.schema(
242+
[("a", pa.float16()), ("b", pa.list_(pa.uint32())), ("c", pa.int64())]
243+
)
244+
245+
assert df.schema() == expected
246+
247+
239248
def test_with_column_renamed(df):
240249
df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum")
241250

0 commit comments

Comments
 (0)