Skip to content

Commit f591617

Browse files
committed
Refactor and enhance expression handling, test coverage, and documentation
- Introduced `ensure_expr_list` to validate and flatten nested expressions, treating strings as atomic - Updated expression utilities to improve consistency across aggregation and window functions - Consolidated and expanded parameterized tests for string equivalence in ranking and window functions - Exposed `EXPR_TYPE_ERROR` for consistent error messaging across modules and tests - Improved internal sort logic using `expr_internal.SortExpr` - Clarified expectations for `join_on` expressions in documentation - Standardized imports and improved test clarity for maintainability
1 parent 54687a2 commit f591617

File tree

6 files changed

+176
-29
lines changed

6 files changed

+176
-29
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ existing column. These include:
138138
* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
139139
* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
140140

141+
Note that :py:meth:`~datafusion.DataFrame.join_on` expects ``col()``/``column()`` expressions rather than plain strings.
142+
141143
For such methods, you can pass column names directly:
142144

143145
.. code-block:: python

python/datafusion/context.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from __future__ import annotations
2121

2222
import warnings
23-
from typing import TYPE_CHECKING, Any, Protocol, Sequence
23+
from typing import TYPE_CHECKING, Any, Protocol
24+
from collections.abc import Sequence
2425

2526
import pyarrow as pa
2627

@@ -39,6 +40,7 @@
3940
from ._internal import SessionConfig as SessionConfigInternal
4041
from ._internal import SessionContext as SessionContextInternal
4142
from ._internal import SQLOptions as SQLOptionsInternal
43+
from ._internal import expr as expr_internal
4244

4345
if TYPE_CHECKING:
4446
import pathlib
@@ -1177,8 +1179,8 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
11771179
@staticmethod
11781180
def _convert_file_sort_order(
11791181
file_sort_order: Sequence[Sequence[SortKey]] | None,
1180-
) -> list[list[Any]] | None:
1181-
"""Convert nested ``SortKey`` sequences into raw sort representations.
1182+
) -> list[list[expr_internal.SortExpr]] | None:
1183+
"""Convert nested ``SortKey`` sequences into raw sort expressions.
11821184
11831185
Each ``SortKey`` can be a column name string, an ``Expr``, or a
11841186
``SortExpr`` and will be converted using

python/datafusion/dataframe.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from __future__ import annotations
2323

2424
import warnings
25+
from collections.abc import Sequence
2526
from typing import (
2627
TYPE_CHECKING,
2728
Any,
@@ -44,6 +45,7 @@
4445
Expr,
4546
SortKey,
4647
ensure_expr,
48+
ensure_expr_list,
4749
expr_list_to_raw_expr_list,
4850
sort_list_to_raw_sort_list,
4951
)
@@ -52,7 +54,7 @@
5254

5355
if TYPE_CHECKING:
5456
import pathlib
55-
from typing import Callable, Sequence
57+
from typing import Callable
5658

5759
import pandas as pd
5860
import polars as pl
@@ -487,17 +489,7 @@ def with_columns(
487489
Returns:
488490
DataFrame with the new columns added.
489491
"""
490-
491-
def _iter_exprs(items: Iterable[Expr | Iterable[Expr]]) -> Iterable[Expr | str]:
492-
for expr in items:
493-
if isinstance(expr, str):
494-
yield expr
495-
elif isinstance(expr, Iterable) and not isinstance(expr, Expr):
496-
yield from _iter_exprs(expr)
497-
else:
498-
yield expr
499-
500-
expressions = [ensure_expr(e) for e in _iter_exprs(exprs)]
492+
expressions = ensure_expr_list(exprs)
501493
for alias, expr in named_exprs.items():
502494
ensure_expr(expr)
503495
expressions.append(expr.alias(alias).expr)
@@ -523,23 +515,31 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
523515

524516
def aggregate(
525517
self,
526-
group_by: list[Expr | str] | Expr | str,
527-
aggs: list[Expr] | Expr,
518+
group_by: Sequence[Expr | str] | Expr | str,
519+
aggs: Sequence[Expr] | Expr,
528520
) -> DataFrame:
529521
"""Aggregates the rows of the current DataFrame.
530522
531523
Args:
532-
group_by: List of expressions or column names to group by.
533-
aggs: List of expressions to aggregate.
524+
group_by: Sequence of expressions or column names to group by.
525+
aggs: Sequence of expressions to aggregate.
534526
535527
Returns:
536528
DataFrame after aggregation.
537529
"""
538-
group_by_list = group_by if isinstance(group_by, list) else [group_by]
539-
aggs_list = aggs if isinstance(aggs, list) else [aggs]
530+
group_by_list = (
531+
list(group_by)
532+
if isinstance(group_by, Sequence) and not isinstance(group_by, (Expr, str))
533+
else [group_by]
534+
)
535+
aggs_list = (
536+
list(aggs)
537+
if isinstance(aggs, Sequence) and not isinstance(aggs, Expr)
538+
else [aggs]
539+
)
540540

541541
group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
542-
aggs_exprs = [ensure_expr(agg) for agg in aggs_list]
542+
aggs_exprs = ensure_expr_list(aggs_list)
543543
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
544544

545545
def sort(self, *exprs: SortKey) -> DataFrame:

python/datafusion/expr.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
25+
from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional
26+
from collections.abc import Sequence
2627

2728
import pyarrow as pa
2829

@@ -131,6 +132,7 @@
131132
WindowExpr = expr_internal.WindowExpr
132133

133134
__all__ = [
135+
"EXPR_TYPE_ERROR",
134136
"Aggregate",
135137
"AggregateFunction",
136138
"Alias",
@@ -219,6 +221,7 @@
219221
"WindowFrame",
220222
"WindowFrameBound",
221223
"ensure_expr",
224+
"ensure_expr_list",
222225
]
223226

224227

@@ -243,6 +246,34 @@ def ensure_expr(value: Expr | Any) -> expr_internal.Expr:
243246
return value.expr
244247

245248

249+
def ensure_expr_list(
250+
exprs: Iterable[Expr | Iterable[Expr]],
251+
) -> list[expr_internal.Expr]:
252+
"""Flatten an iterable of expressions, validating each via ``ensure_expr``.
253+
254+
Args:
255+
exprs: Possibly nested iterable containing expressions.
256+
257+
Returns:
258+
A flat list of raw expressions.
259+
260+
Raises:
261+
TypeError: If any item is not an instance of :class:`Expr`.
262+
"""
263+
264+
def _iter(items: Iterable[Expr | Iterable[Expr]]) -> Iterable[expr_internal.Expr]:
265+
for expr in items:
266+
if isinstance(expr, Iterable) and not isinstance(
267+
expr, (Expr, str, bytes, bytearray)
268+
):
269+
# Treat string-like objects as atomic to surface standard errors
270+
yield from _iter(expr)
271+
else:
272+
yield ensure_expr(expr)
273+
274+
return list(_iter(exprs))
275+
276+
246277
def _to_raw_expr(value: Expr | str) -> expr_internal.Expr:
247278
"""Convert a Python expression or column name to its raw variant.
248279

python/tests/test_dataframe.py

Lines changed: 102 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,18 @@ def test_aggregate_string_and_expression_equivalent(df):
303303
assert result_str == result_expr
304304

305305

306+
def test_aggregate_tuple_group_by(df):
307+
result_list = df.aggregate(["a"], [f.count()]).sort("a").to_pydict()
308+
result_tuple = df.aggregate(("a",), [f.count()]).sort("a").to_pydict()
309+
assert result_tuple == result_list
310+
311+
312+
def test_aggregate_tuple_aggs(df):
313+
result_list = df.aggregate("a", [f.count()]).sort("a").to_pydict()
314+
result_tuple = df.aggregate("a", (f.count(),)).sort("a").to_pydict()
315+
assert result_tuple == result_list
316+
317+
306318
def test_filter_string_unsupported(df):
307319
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
308320
df.filter("a > 1")
@@ -416,14 +428,14 @@ def test_with_columns(df):
416428

417429

418430
def test_with_columns_invalid_expr(df):
419-
with pytest.raises(
420-
TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
421-
):
431+
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
422432
df.with_columns("a")
423-
with pytest.raises(
424-
TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
425-
):
433+
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
426434
df.with_columns(c="a")
435+
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
436+
df.with_columns(["a"])
437+
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
438+
df.with_columns(c=["a"])
427439

428440

429441
def test_cast(df):
@@ -843,6 +855,27 @@ def test_window_functions(partitioned_df, name, expr, result):
843855
assert table.sort_by("a").to_pydict() == expected
844856

845857

858+
@pytest.mark.parametrize("partition", ["c", df_col("c")])
859+
def test_rank_partition_by_accepts_string(partitioned_df, partition):
860+
"""Passing a string to partition_by should match using col()."""
861+
df = partitioned_df.select(
862+
f.rank(order_by=column("a"), partition_by=partition).alias("r")
863+
)
864+
table = pa.Table.from_batches(df.sort(column("a")).collect())
865+
assert table.column("r").to_pylist() == [1, 2, 3, 4, 1, 2, 3]
866+
867+
868+
@pytest.mark.parametrize("partition", ["c", df_col("c")])
869+
def test_window_partition_by_accepts_string(partitioned_df, partition):
870+
"""Window.partition_by accepts string identifiers."""
871+
expr = f.first_value(column("a")).over(
872+
Window(partition_by=partition, order_by=column("b"))
873+
)
874+
df = partitioned_df.select(expr.alias("fv"))
875+
table = pa.Table.from_batches(df.sort(column("a")).collect())
876+
assert table.column("fv").to_pylist() == [1, 1, 1, 1, 5, 5, 5]
877+
878+
846879
@pytest.mark.parametrize(
847880
("units", "start_bound", "end_bound"),
848881
[
@@ -913,6 +946,69 @@ def test_window_frame_defaults_match_postgres(partitioned_df):
913946
assert df_2.sort(col_a).to_pydict() == expected
914947

915948

949+
def _build_last_value_df(df):
950+
return df.select(
951+
f.last_value(column("a"))
952+
.over(
953+
Window(
954+
partition_by=[column("c")],
955+
order_by=[column("b")],
956+
window_frame=WindowFrame("rows", None, None),
957+
)
958+
)
959+
.alias("expr"),
960+
f.last_value(column("a"))
961+
.over(
962+
Window(
963+
partition_by=[column("c")],
964+
order_by="b",
965+
window_frame=WindowFrame("rows", None, None),
966+
)
967+
)
968+
.alias("str"),
969+
)
970+
971+
972+
def _build_nth_value_df(df):
973+
return df.select(
974+
f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])).alias("expr"),
975+
f.nth_value(column("b"), 3).over(Window(order_by="a")).alias("str"),
976+
)
977+
978+
979+
def _build_rank_df(df):
980+
return df.select(
981+
f.rank(order_by=[column("b")]).alias("expr"),
982+
f.rank(order_by="b").alias("str"),
983+
)
984+
985+
986+
def _build_array_agg_df(df):
987+
return df.aggregate(
988+
[column("c")],
989+
[
990+
f.array_agg(column("a"), order_by=[column("a")]).alias("expr"),
991+
f.array_agg(column("a"), order_by="a").alias("str"),
992+
],
993+
).sort(column("c"))
994+
995+
996+
@pytest.mark.parametrize(
997+
("builder", "expected"),
998+
[
999+
pytest.param(_build_last_value_df, [3, 3, 3, 3, 6, 6, 6], id="last_value"),
1000+
pytest.param(_build_nth_value_df, [None, None, 7, 7, 7, 7, 7], id="nth_value"),
1001+
pytest.param(_build_rank_df, [1, 1, 3, 3, 5, 6, 6], id="rank"),
1002+
pytest.param(_build_array_agg_df, [[0, 1, 2, 3], [4, 5, 6]], id="array_agg"),
1003+
],
1004+
)
1005+
def test_order_by_string_equivalence(partitioned_df, builder, expected):
1006+
df = builder(partitioned_df)
1007+
table = pa.Table.from_batches(df.collect())
1008+
assert table.column("expr").to_pylist() == expected
1009+
assert table.column("expr").to_pylist() == table.column("str").to_pylist()
1010+
1011+
9161012
def test_html_formatter_cell_dimension(df, clean_formatter_state):
9171013
"""Test configuring the HTML formatter with different options."""
9181014
# Configure with custom settings

python/tests/test_expr.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
TransactionStart,
5151
Values,
5252
ensure_expr,
53+
ensure_expr_list,
5354
)
5455

5556

@@ -890,3 +891,18 @@ def test_ensure_expr():
890891
assert ensure_expr(e) is e.expr
891892
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
892893
ensure_expr("a")
894+
895+
896+
def test_ensure_expr_list_string():
897+
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
898+
ensure_expr_list("a")
899+
900+
901+
def test_ensure_expr_list_bytes():
902+
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
903+
ensure_expr_list(b"a")
904+
905+
906+
def test_ensure_expr_list_bytearray():
907+
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
908+
ensure_expr_list(bytearray(b"a"))

0 commit comments

Comments
 (0)