Skip to content

Commit 35656f1

Browse files
committed
feat: add with_columns
1 parent cdec202 commit 35656f1

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

python/datafusion/dataframe.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from __future__ import annotations
2323

24-
from typing import Any, List, TYPE_CHECKING
24+
from typing import Any, Iterable, List, TYPE_CHECKING
2525
from datafusion.record_batch import RecordBatchStream
2626
from typing_extensions import deprecated
2727
from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -160,6 +160,40 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
160160
"""
161161
return DataFrame(self.df.with_column(name, expr.expr))
162162

163+
def with_columns(
164+
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
165+
) -> DataFrame:
166+
"""Add one or more columns to the DataFrame.
167+
168+
Args:
169+
*exprs: Expressions to add, passed individually or as iterables of expressions.
170+
**named_exprs: Expressions to add, each aliased to its keyword argument name.
171+
172+
Returns:
173+
DataFrame with the new columns.
174+
"""
175+
176+
# Flatten the positional and keyword arguments into one list of the
# `.expr` objects carried by each Expr, positional arguments first.
def _simplify_expression(
177+
*exprs: Expr | Iterable[Expr], **named_exprs: Expr
178+
) -> list[Expr]:
179+
expr_list = []
180+
for expr in exprs:
181+
if isinstance(expr, Expr):
182+
expr_list.append(expr.expr)
183+
elif isinstance(expr, Iterable):
184+
for inner_expr in expr:
185+
expr_list.append(inner_expr.expr)
186+
else:
187+
# neither an Expr nor an iterable: reject the argument
raise NotImplementedError
188+
if named_exprs:
189+
for alias, expr in named_exprs.items():
190+
# the keyword name becomes the alias of the expression
expr_list.append(expr.alias(alias).expr)
191+
return expr_list
192+
193+
expressions = _simplify_expression(*exprs, **named_exprs)
194+
195+
return DataFrame(self.df.with_columns(expressions))
196+
163197
def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
164198
r"""Rename one column by applying a new projection.
165199

python/tests/test_dataframe.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,37 @@ def test_with_column(df):
205205
assert result.column(2) == pa.array([5, 7, 9])
206206

207207

208+
def test_with_columns(df):
209+
# exercise all three argument forms: bare expressions, a list of
# expressions, and a keyword-named expression
df = df.with_columns(
210+
(column("a") + column("b")).alias("c"),
211+
(column("a") + column("b")).alias("d"),
212+
[
213+
(column("a") + column("b")).alias("e"),
214+
(column("a") + column("b")).alias("f"),
215+
],
216+
g=(column("a") + column("b")),
217+
)
218+
219+
# execute and collect the first (and only) batch
220+
result = df.collect()[0]
221+
222+
# existing columns a and b come first, then the added columns in argument order
assert result.schema.field(0).name == "a"
223+
assert result.schema.field(1).name == "b"
224+
assert result.schema.field(2).name == "c"
225+
assert result.schema.field(3).name == "d"
226+
assert result.schema.field(4).name == "e"
227+
assert result.schema.field(5).name == "f"
228+
assert result.schema.field(6).name == "g"
229+
230+
# every added column is a + b
assert result.column(0) == pa.array([1, 2, 3])
231+
assert result.column(1) == pa.array([4, 5, 6])
232+
assert result.column(2) == pa.array([5, 7, 9])
233+
assert result.column(3) == pa.array([5, 7, 9])
234+
assert result.column(4) == pa.array([5, 7, 9])
235+
assert result.column(5) == pa.array([5, 7, 9])
236+
assert result.column(6) == pa.array([5, 7, 9])
237+
238+
208239
def test_with_column_renamed(df):
209240
df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum")
210241

src/dataframe.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,16 @@ impl PyDataFrame {
180180
Ok(Self::new(df))
181181
}
182182

183+
/// Add each expression in `exprs` as a column by calling `with_column` once per
/// expression, in order. Each column is named after the expression's `schema_name()`.
/// Returns the error of the first `with_column` call that fails.
fn with_columns(&self, exprs: Vec<PyExpr>) -> PyResult<Self> {
184+
let mut df = self.df.as_ref().clone();
185+
for expr in exprs {
186+
let expr: Expr = expr.into();
187+
// NOTE(review): presumably an aliased expression yields its alias here — confirm.
// `format!("{}", x)` could be written `x.to_string()`.
let name = format!("{}", expr.schema_name());
188+
df = df.with_column(name.as_str(), expr)?
189+
}
190+
Ok(Self::new(df))
191+
}
192+
183193
/// Rename one column by applying a new projection. This is a no-op if the column to be
184194
/// renamed does not exist.
185195
fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyResult<Self> {

0 commit comments

Comments
 (0)