Skip to content

Commit 694a5d8

Browse files
feat: Add SQL expression support for with_columns (#1286)
* add SQL expression support for `with_columns`
* fix ruff errors
* Update python/datafusion/dataframe.py

  Co-authored-by: Hendrik Makait <[email protected]>
* Update python/datafusion/dataframe.py

  Co-authored-by: Hendrik Makait <[email protected]>
* remove parentheses
* update example
* fix ident

---------

Co-authored-by: Hendrik Makait <[email protected]>
1 parent d9c90d2 commit 694a5d8

File tree

2 files changed: 64 additions and 16 deletions.

python/datafusion/dataframe.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -545,13 +545,14 @@ def with_column(self, name: str, expr: Expr | str) -> DataFrame:
545545
return DataFrame(self.df.with_column(name, ensure_expr(expr)))
546546

547547
def with_columns(
548-
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
548+
self, *exprs: Expr | str | Iterable[Expr | str], **named_exprs: Expr | str
549549
) -> DataFrame:
550550
"""Add columns to the DataFrame.
551551
552-
By passing expressions, iterables of expressions, or named expressions.
552+
By passing expressions, iterables of expressions, string SQL expressions,
553+
or named expressions.
553554
All expressions must be :class:`~datafusion.expr.Expr` objects created via
554-
:func:`datafusion.col` or :func:`datafusion.lit`.
555+
:func:`datafusion.col` or :func:`datafusion.lit`, or SQL expression strings.
555556
To pass named expressions use the form ``name=Expr``.
556557
557558
Example usage: The following will add 4 columns labeled ``a``, ``b``, ``c``,
@@ -564,17 +565,44 @@ def with_columns(
564565
d=lit(3)
565566
)
566567
568+
Equivalent example using just SQL strings:
569+
570+
df = df.with_columns(
571+
"x as a",
572+
["1 as b", "y as c"],
573+
d="3"
574+
)
575+
567576
Args:
568-
exprs: Either a single expression or an iterable of expressions to add.
577+
exprs: Either a single expression, an iterable of expressions to add or
578+
SQL expression strings.
569579
named_exprs: Named expressions in the form of ``name=expr``
570580
571581
Returns:
572582
DataFrame with the new columns added.
573583
"""
574-
expressions = ensure_expr_list(exprs)
584+
expressions = []
585+
for expr in exprs:
586+
if isinstance(expr, str):
587+
expressions.append(self.parse_sql_expr(expr).expr)
588+
elif isinstance(expr, Iterable) and not isinstance(
589+
expr, (Expr, str, bytes, bytearray)
590+
):
591+
expressions.extend(
592+
[
593+
self.parse_sql_expr(e).expr
594+
if isinstance(e, str)
595+
else ensure_expr(e)
596+
for e in expr
597+
]
598+
)
599+
else:
600+
expressions.append(ensure_expr(expr))
601+
575602
for alias, expr in named_exprs.items():
576-
ensure_expr(expr)
577-
expressions.append(expr.alias(alias).expr)
603+
e = self.parse_sql_expr(expr) if isinstance(expr, str) else expr
604+
ensure_expr(e)
605+
expressions.append(e.alias(alias).expr)
578606

579607
return DataFrame(self.df.with_columns(expressions))
580608

python/tests/test_dataframe.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -538,15 +538,35 @@ def test_with_columns(df):
538538
assert result.column(6) == pa.array([5, 7, 9])
539539

540540

541-
def test_with_columns_invalid_expr(df):
542-
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
543-
df.with_columns("a")
544-
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
545-
df.with_columns(c="a")
546-
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
547-
df.with_columns(["a"])
548-
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
549-
df.with_columns(c=["a"])
541+
def test_with_columns_str(df):
542+
df = df.with_columns(
543+
"a + b as c",
544+
"a + b as d",
545+
[
546+
"a + b as e",
547+
"a + b as f",
548+
],
549+
g="a + b",
550+
)
551+
552+
# execute and collect the first (and only) batch
553+
result = df.collect()[0]
554+
555+
assert result.schema.field(0).name == "a"
556+
assert result.schema.field(1).name == "b"
557+
assert result.schema.field(2).name == "c"
558+
assert result.schema.field(3).name == "d"
559+
assert result.schema.field(4).name == "e"
560+
assert result.schema.field(5).name == "f"
561+
assert result.schema.field(6).name == "g"
562+
563+
assert result.column(0) == pa.array([1, 2, 3])
564+
assert result.column(1) == pa.array([4, 5, 6])
565+
assert result.column(2) == pa.array([5, 7, 9])
566+
assert result.column(3) == pa.array([5, 7, 9])
567+
assert result.column(4) == pa.array([5, 7, 9])
568+
assert result.column(5) == pa.array([5, 7, 9])
569+
assert result.column(6) == pa.array([5, 7, 9])
550570

551571

552572
def test_cast(df):

0 commit comments

Comments
 (0)