Skip to content

Commit 91167b0

Browse files
committed
refactor: unify expression and sorting logic; improve docs and error handling
- Update `order_by` handling in Window class for better type support
- Improve type checking in DataFrame expression handling
- Replace `Expr`/`SortExpr` with `SortKey` in file_sort_order and related functions
- Simplify file_sort_order handling in SessionContext
- Rename `_EXPR_TYPE_ERROR` → `EXPR_TYPE_ERROR` for consistency
- Clarify usage of `col()` vs `column()` in DataFrame examples
- Enhance documentation for file_sort_order in SessionContext
1 parent f9cafb8 commit 91167b0

File tree

6 files changed

+85
-66
lines changed

6 files changed

+85
-66
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,21 +142,21 @@ For such methods, you can pass column names directly:
142142

143143
.. code-block:: python
144144
145-
from datafusion import col, column, functions as f
145+
from datafusion import col, functions as f
146146
147147
df.sort('id')
148148
df.aggregate('id', [f.count(col('value'))])
149149
150-
The same operation can also be written with an explicit column expression:
150+
The same operation can also be written with explicit column expressions, using either ``col()`` or ``column()``:
151151

152152
.. code-block:: python
153153
154154
from datafusion import col, column, functions as f
155155
156156
df.sort(col('id'))
157-
df.aggregate(col('id'), [f.count(col('value'))])
157+
df.aggregate(column('id'), [f.count(col('value'))])
158158
159-
Note that ``column()`` is an alias of ``col()``, so you can use either name.
159+
Note that ``column()`` is an alias of ``col()``, so you can use either name; the example above shows both in action.
160160

161161
Whenever an argument represents an expression—such as in
162162
:py:meth:`~datafusion.DataFrame.filter` or

python/datafusion/context.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from datafusion.catalog import Catalog, CatalogProvider, Table
3333
from datafusion.dataframe import DataFrame
34-
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
34+
from datafusion.expr import SortKey, sort_list_to_raw_sort_list
3535
from datafusion.record_batch import RecordBatchStream
3636
from datafusion.user_defined import AggregateUDF, ScalarUDF, TableFunction, WindowUDF
3737

@@ -553,7 +553,7 @@ def register_listing_table(
553553
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
554554
file_extension: str = ".parquet",
555555
schema: pa.Schema | None = None,
556-
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
556+
file_sort_order: list[list[SortKey]] | None = None,
557557
) -> None:
558558
"""Register multiple files as a single table.
559559
@@ -567,23 +567,20 @@ def register_listing_table(
567567
table_partition_cols: Partition columns.
568568
file_extension: File extension of the provided table.
569569
schema: The data source schema.
570-
file_sort_order: Sort order for the file.
570+
file_sort_order: Sort order for the file. Each sort key can be
571+
specified as a column name (``str``), an expression
572+
(``Expr``), or a ``SortExpr``.
571573
"""
572574
if table_partition_cols is None:
573575
table_partition_cols = []
574576
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
575-
file_sort_order_raw = (
576-
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
577-
if file_sort_order is not None
578-
else None
579-
)
580577
self.ctx.register_listing_table(
581578
name,
582579
str(path),
583580
table_partition_cols,
584581
file_extension,
585582
schema,
586-
file_sort_order_raw,
583+
self._convert_file_sort_order(file_sort_order),
587584
)
588585

589586
def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
@@ -808,7 +805,7 @@ def register_parquet(
808805
file_extension: str = ".parquet",
809806
skip_metadata: bool = True,
810807
schema: pa.Schema | None = None,
811-
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
808+
file_sort_order: list[list[SortKey]] | None = None,
812809
) -> None:
813810
"""Register a Parquet file as a table.
814811
@@ -827,7 +824,9 @@ def register_parquet(
827824
that may be in the file schema. This can help avoid schema
828825
conflicts due to metadata.
829826
schema: The data source schema.
830-
file_sort_order: Sort order for the file.
827+
file_sort_order: Sort order for the file. Each sort key can be
828+
specified as a column name (``str``), an expression
829+
(``Expr``), or a ``SortExpr``.
831830
"""
832831
if table_partition_cols is None:
833832
table_partition_cols = []
@@ -840,9 +839,7 @@ def register_parquet(
840839
file_extension,
841840
skip_metadata,
842841
schema,
843-
[sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order]
844-
if file_sort_order is not None
845-
else None,
842+
self._convert_file_sort_order(file_sort_order),
846843
)
847844

848845
def register_csv(
@@ -1099,7 +1096,7 @@ def read_parquet(
10991096
file_extension: str = ".parquet",
11001097
skip_metadata: bool = True,
11011098
schema: pa.Schema | None = None,
1102-
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
1099+
file_sort_order: list[list[SortKey]] | None = None,
11031100
) -> DataFrame:
11041101
"""Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
11051102
@@ -1116,19 +1113,17 @@ def read_parquet(
11161113
schema: An optional schema representing the parquet files. If None,
11171114
the parquet reader will try to infer it based on data in the
11181115
file.
1119-
file_sort_order: Sort order for the file.
1116+
file_sort_order: Sort order for the file. Each sort key can be
1117+
specified as a column name (``str``), an expression
1118+
(``Expr``), or a ``SortExpr``.
11201119
11211120
Returns:
11221121
DataFrame representation of the read Parquet files
11231122
"""
11241123
if table_partition_cols is None:
11251124
table_partition_cols = []
11261125
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
1127-
file_sort_order = (
1128-
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
1129-
if file_sort_order is not None
1130-
else None
1131-
)
1126+
file_sort_order = self._convert_file_sort_order(file_sort_order)
11321127
return DataFrame(
11331128
self.ctx.read_parquet(
11341129
str(path),
@@ -1179,6 +1174,16 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
11791174
"""Execute the ``plan`` and return the results."""
11801175
return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
11811176

1177+
@staticmethod
1178+
def _convert_file_sort_order(
1179+
file_sort_order: list[list[Expr | SortExpr | str]] | None,
1180+
) -> list[list[Any]] | None:
1181+
return (
1182+
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
1183+
if file_sort_order is not None
1184+
else None
1185+
)
1186+
11821187
@staticmethod
11831188
def _convert_table_partition_cols(
11841189
table_partition_cols: list[tuple[str, str | pa.DataType]],

python/datafusion/dataframe.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
4343
from datafusion.expr import (
44-
_EXPR_TYPE_ERROR,
44+
EXPR_TYPE_ERROR,
4545
Expr,
46-
SortExpr,
46+
SortKey,
4747
expr_list_to_raw_expr_list,
4848
sort_list_to_raw_sort_list,
4949
)
@@ -431,7 +431,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
431431
df = self.df
432432
for p in predicates:
433433
if not isinstance(p, Expr):
434-
raise TypeError(_EXPR_TYPE_ERROR)
434+
raise TypeError(EXPR_TYPE_ERROR)
435435
df = df.filter(p.expr)
436436
return DataFrame(df)
437437

@@ -446,7 +446,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
446446
DataFrame with the new column.
447447
"""
448448
if not isinstance(expr, Expr):
449-
raise TypeError(_EXPR_TYPE_ERROR)
449+
raise TypeError(EXPR_TYPE_ERROR)
450450
return DataFrame(self.df.with_column(name, expr.expr))
451451

452452
def with_columns(
@@ -478,19 +478,21 @@ def _simplify_expression(
478478
) -> list[expr_internal.Expr]:
479479
expr_list: list[expr_internal.Expr] = []
480480
for expr in exprs:
481-
if isinstance(expr, str) or (
482-
isinstance(expr, Iterable)
483-
and not isinstance(expr, Expr)
484-
and any(isinstance(inner, str) for inner in expr)
485-
):
486-
raise TypeError(_EXPR_TYPE_ERROR)
481+
if isinstance(expr, str):
482+
raise TypeError(EXPR_TYPE_ERROR)
483+
if isinstance(expr, Iterable) and not isinstance(expr, Expr):
484+
expr_value = list(expr)
485+
if any(isinstance(inner, str) for inner in expr_value):
486+
raise TypeError(EXPR_TYPE_ERROR)
487+
else:
488+
expr_value = expr
487489
try:
488-
expr_list.extend(expr_list_to_raw_expr_list(expr))
490+
expr_list.extend(expr_list_to_raw_expr_list(expr_value))
489491
except TypeError as err:
490-
raise TypeError(_EXPR_TYPE_ERROR) from err
492+
raise TypeError(EXPR_TYPE_ERROR) from err
491493
for alias, expr in named_exprs.items():
492494
if not isinstance(expr, Expr):
493-
raise TypeError(_EXPR_TYPE_ERROR)
495+
raise TypeError(EXPR_TYPE_ERROR)
494496
expr_list.append(expr.alias(alias).expr)
495497
return expr_list
496498

@@ -536,11 +538,11 @@ def aggregate(
536538
aggs_exprs = []
537539
for agg in aggs_list:
538540
if not isinstance(agg, Expr):
539-
raise TypeError(_EXPR_TYPE_ERROR)
541+
raise TypeError(EXPR_TYPE_ERROR)
540542
aggs_exprs.append(agg.expr)
541543
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
542544

543-
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
545+
def sort(self, *exprs: SortKey) -> DataFrame:
544546
"""Sort the DataFrame by the specified sorting expressions or column names.
545547
546548
Note that any expression can be turned into a sort expression by
@@ -779,7 +781,7 @@ def join_on(
779781
exprs = []
780782
for expr in on_exprs:
781783
if not isinstance(expr, Expr):
782-
raise TypeError(_EXPR_TYPE_ERROR)
784+
raise TypeError(EXPR_TYPE_ERROR)
783785
exprs.append(expr.expr)
784786
return DataFrame(self.df.join_on(right.df, exprs, how))
785787

python/datafusion/expr.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence, Union
2626

2727
import pyarrow as pa
2828

@@ -41,7 +41,9 @@
4141

4242

4343
# Standard error message for invalid expression types
44-
_EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions"
44+
EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions"
45+
46+
SortKey = Union["Expr", "SortExpr", str]
4547

4648
# The following are imported from the internal representation. We may choose to
4749
# give these all proper wrappers, or to simply leave as is. These were added
@@ -199,6 +201,7 @@
199201
"SimilarTo",
200202
"Sort",
201203
"SortExpr",
204+
"SortKey",
202205
"Subquery",
203206
"SubqueryAlias",
204207
"TableScan",
@@ -236,7 +239,7 @@ def expr_list_to_raw_expr_list(
236239
else:
237240
error = (
238241
"Expected Expr or column name, found:"
239-
f" {type(e).__name__}. {_EXPR_TYPE_ERROR}."
242+
f" {type(e).__name__}. {EXPR_TYPE_ERROR}."
240243
)
241244
raise TypeError(error)
242245
return raw_exprs
@@ -250,7 +253,7 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
250253

251254

252255
def sort_list_to_raw_sort_list(
253-
sort_list: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str],
256+
sort_list: Optional[list[SortKey] | SortKey],
254257
) -> Optional[list[expr_internal.SortExpr]]:
255258
"""Helper function to return an optional sort list to raw variant."""
256259
if isinstance(sort_list, (Expr, SortExpr, str)):
@@ -266,7 +269,7 @@ def sort_list_to_raw_sort_list(
266269
else:
267270
error = (
268271
"Expected Expr or column name, found:"
269-
f" {type(item).__name__}. {_EXPR_TYPE_ERROR}."
272+
f" {type(item).__name__}. {EXPR_TYPE_ERROR}."
270273
)
271274
raise TypeError(error)
272275
raw_sort_list.append(sort_or_default(expr_obj))
@@ -693,7 +696,7 @@ def over(self, window: Window) -> Expr:
693696
window: Window definition
694697
"""
695698
partition_by_raw = expr_list_to_raw_expr_list(window._partition_by)
696-
order_by_raw = sort_list_to_raw_sort_list(window._order_by)
699+
order_by_raw = window._order_by
697700
window_frame_raw = (
698701
window._window_frame.window_frame
699702
if window._window_frame is not None
@@ -1179,7 +1182,7 @@ def __init__(
11791182
self,
11801183
partition_by: Optional[list[Expr] | Expr] = None,
11811184
window_frame: Optional[WindowFrame] = None,
1182-
order_by: Optional[list[SortExpr | Expr] | Expr | SortExpr] = None,
1185+
order_by: Optional[list[SortExpr | Expr | str] | Expr | SortExpr | str] = None,
11831186
null_treatment: Optional[NullTreatment] = None,
11841187
) -> None:
11851188
"""Construct a window definition.
@@ -1192,7 +1195,7 @@ def __init__(
11921195
"""
11931196
self._partition_by = partition_by
11941197
self._window_frame = window_frame
1195-
self._order_by = order_by
1198+
self._order_by = sort_list_to_raw_sort_list(order_by)
11961199
self._null_treatment = null_treatment
11971200

11981201

0 commit comments

Comments
 (0)