Skip to content

Commit a93d614

Browse files
authored
with_column supports SQL expression (#1284)
1 parent 29bcb0f commit a93d614

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

python/datafusion/dataframe.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,11 +521,12 @@ def parse_sql_expr(self, expr: str) -> Expr:
521521
"""
522522
return Expr(self.df.parse_sql_expr(expr))
523523

524-
def with_column(self, name: str, expr: Expr) -> DataFrame:
524+
def with_column(self, name: str, expr: Expr | str) -> DataFrame:
525525
"""Add an additional column to the DataFrame.
526526
527527
The ``expr`` must be an :class:`~datafusion.expr.Expr` constructed with
528-
:func:`datafusion.col` or :func:`datafusion.lit`.
528+
:func:`datafusion.col` or :func:`datafusion.lit`, or a SQL expression
529+
string that will be parsed against the DataFrame schema.
529530
530531
Example::
531532
@@ -539,6 +540,8 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
539540
Returns:
540541
DataFrame with the new column.
541542
"""
543+
expr = self.parse_sql_expr(expr) if isinstance(expr, str) else expr
544+
542545
return DataFrame(self.df.with_column(name, ensure_expr(expr)))
543546

544547
def with_columns(

python/tests/test_dataframe.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -477,8 +477,8 @@ def test_tail(df):
477477
assert result.column(2) == pa.array([8])
478478

479479

480-
def test_with_column(df):
481-
df = df.with_column("c", column("a") + column("b"))
480+
def test_with_column_sql_expression(df):
481+
df = df.with_column("c", "a + b")
482482

483483
# execute and collect the first (and only) batch
484484
result = df.collect()[0]
@@ -492,11 +492,19 @@ def test_with_column(df):
492492
assert result.column(2) == pa.array([5, 7, 9])
493493

494494

495-
def test_with_column_invalid_expr(df):
496-
with pytest.raises(
497-
TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
498-
):
499-
df.with_column("c", "a")
495+
def test_with_column(df):
496+
df = df.with_column("c", column("a") + column("b"))
497+
498+
# execute and collect the first (and only) batch
499+
result = df.collect()[0]
500+
501+
assert result.schema.field(0).name == "a"
502+
assert result.schema.field(1).name == "b"
503+
assert result.schema.field(2).name == "c"
504+
505+
assert result.column(0) == pa.array([1, 2, 3])
506+
assert result.column(1) == pa.array([4, 5, 6])
507+
assert result.column(2) == pa.array([5, 7, 9])
500508

501509

502510
def test_with_columns(df):

0 commit comments

Comments
 (0)