Skip to content

Commit 54687a2

Browse files
committed
feat: add ensure_expr helper for validation; refine expression handling, sorting, and docs
- Introduce `ensure_expr` helper and improve internal expression validation - Update error messages and tests to consistently use `EXPR_TYPE_ERROR` - Refactor expression handling with `_to_raw_expr`, `_ensure_expr`, and `SortKey` - Improve type safety and consistency in sort key definitions and file sort order - Add parameterized parquet sorting tests - Enhance DataFrame docstrings with clearer guidance and usage examples - Fix minor typos and error message clarity
1 parent 91167b0 commit 54687a2

File tree

6 files changed

+273
-109
lines changed

6 files changed

+273
-109
lines changed

python/datafusion/context.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from __future__ import annotations
2121

2222
import warnings
23-
from typing import TYPE_CHECKING, Any, Protocol
23+
from typing import TYPE_CHECKING, Any, Protocol, Sequence
2424

2525
import pyarrow as pa
2626

@@ -553,7 +553,7 @@ def register_listing_table(
553553
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
554554
file_extension: str = ".parquet",
555555
schema: pa.Schema | None = None,
556-
file_sort_order: list[list[SortKey]] | None = None,
556+
file_sort_order: Sequence[Sequence[SortKey]] | None = None,
557557
) -> None:
558558
"""Register multiple files as a single table.
559559
@@ -805,7 +805,7 @@ def register_parquet(
805805
file_extension: str = ".parquet",
806806
skip_metadata: bool = True,
807807
schema: pa.Schema | None = None,
808-
file_sort_order: list[list[SortKey]] | None = None,
808+
file_sort_order: Sequence[Sequence[SortKey]] | None = None,
809809
) -> None:
810810
"""Register a Parquet file as a table.
811811
@@ -1096,7 +1096,7 @@ def read_parquet(
10961096
file_extension: str = ".parquet",
10971097
skip_metadata: bool = True,
10981098
schema: pa.Schema | None = None,
1099-
file_sort_order: list[list[SortKey]] | None = None,
1099+
file_sort_order: Sequence[Sequence[SortKey]] | None = None,
11001100
) -> DataFrame:
11011101
"""Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
11021102
@@ -1176,8 +1176,16 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
11761176

11771177
@staticmethod
11781178
def _convert_file_sort_order(
1179-
file_sort_order: list[list[Expr | SortExpr | str]] | None,
1179+
file_sort_order: Sequence[Sequence[SortKey]] | None,
11801180
) -> list[list[Any]] | None:
1181+
"""Convert nested ``SortKey`` sequences into raw sort representations.
1182+
1183+
Each ``SortKey`` can be a column name string, an ``Expr``, or a
1184+
``SortExpr`` and will be converted using
1185+
:func:`datafusion.expr.sort_list_to_raw_sort_list`.
1186+
"""
1187+
# Convert each ``SortKey`` in the provided sort order to the low-level
1188+
# representation expected by the Rust bindings.
11811189
return (
11821190
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
11831191
if file_sort_order is not None

python/datafusion/dataframe.py

Lines changed: 56 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
4343
from datafusion.expr import (
44-
EXPR_TYPE_ERROR,
4544
Expr,
4645
SortKey,
46+
ensure_expr,
4747
expr_list_to_raw_expr_list,
4848
sort_list_to_raw_sort_list,
4949
)
@@ -58,8 +58,6 @@
5858
import polars as pl
5959
import pyarrow as pa
6060

61-
from datafusion._internal import expr as expr_internal
62-
6361
from enum import Enum
6462

6563

@@ -418,9 +416,17 @@ def filter(self, *predicates: Expr) -> DataFrame:
418416
"""Return a DataFrame for which ``predicate`` evaluates to ``True``.
419417
420418
Rows for which ``predicate`` evaluates to ``False`` or ``None`` are filtered
421-
out. If more than one predicate is provided, these predicates will be
422-
combined as a logical AND. If more complex logic is required, see the
423-
logical operations in :py:mod:`~datafusion.functions`.
419+
out. If more than one predicate is provided, these predicates will be
420+
combined as a logical AND. Each ``predicate`` must be an
421+
:class:`~datafusion.expr.Expr` created using helper functions such as
422+
:func:`datafusion.col` or :func:`datafusion.lit`; plain strings are not
423+
accepted. If more complex logic is required, see the logical operations in
424+
:py:mod:`~datafusion.functions`.
425+
426+
Example::
427+
428+
from datafusion import col, lit
429+
df.filter(col("a") > lit(1))
424430
425431
Args:
426432
predicates: Predicate expression(s) to filter the DataFrame.
@@ -430,40 +436,49 @@ def filter(self, *predicates: Expr) -> DataFrame:
430436
"""
431437
df = self.df
432438
for p in predicates:
433-
if not isinstance(p, Expr):
434-
raise TypeError(EXPR_TYPE_ERROR)
435-
df = df.filter(p.expr)
439+
df = df.filter(ensure_expr(p))
436440
return DataFrame(df)
437441

438442
def with_column(self, name: str, expr: Expr) -> DataFrame:
439443
"""Add an additional column to the DataFrame.
440444
445+
The ``expr`` must be an :class:`~datafusion.expr.Expr` constructed with
446+
:func:`datafusion.col` or :func:`datafusion.lit`; plain strings are not
447+
accepted.
448+
449+
Example::
450+
451+
from datafusion import col, lit
452+
df.with_column("b", col("a") + lit(1))
453+
441454
Args:
442455
name: Name of the column to add.
443456
expr: Expression to compute the column.
444457
445458
Returns:
446459
DataFrame with the new column.
447460
"""
448-
if not isinstance(expr, Expr):
449-
raise TypeError(EXPR_TYPE_ERROR)
450-
return DataFrame(self.df.with_column(name, expr.expr))
461+
return DataFrame(self.df.with_column(name, ensure_expr(expr)))
451462

452463
def with_columns(
453464
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
454465
) -> DataFrame:
455466
"""Add columns to the DataFrame.
456467
457-
By passing expressions, iteratables of expressions, or named expressions. To
458-
pass named expressions use the form name=Expr.
468+
By passing expressions, iterables of expressions, or named expressions.
469+
All expressions must be :class:`~datafusion.expr.Expr` objects created via
470+
:func:`datafusion.col` or :func:`datafusion.lit`; plain strings are not
471+
accepted. To pass named expressions use the form ``name=Expr``.
459472
460-
Example usage: The following will add 4 columns labeled a, b, c, and d::
473+
Example usage: The following will add 4 columns labeled ``a``, ``b``, ``c``,
474+
and ``d``::
461475
476+
from datafusion import col, lit
462477
df = df.with_columns(
463-
lit(0).alias('a'),
464-
[lit(1).alias('b'), lit(2).alias('c')],
478+
col("x").alias("a"),
479+
[lit(1).alias("b"), col("y").alias("c")],
465480
d=lit(3)
466-
)
481+
)
467482
468483
Args:
469484
exprs: Either a single expression or an iterable of expressions to add.
@@ -473,30 +488,19 @@ def with_columns(
473488
DataFrame with the new columns added.
474489
"""
475490

476-
def _simplify_expression(
477-
*exprs: Expr | Iterable[Expr], **named_exprs: Expr
478-
) -> list[expr_internal.Expr]:
479-
expr_list: list[expr_internal.Expr] = []
480-
for expr in exprs:
491+
def _iter_exprs(items: Iterable[Expr | Iterable[Expr]]) -> Iterable[Expr | str]:
492+
for expr in items:
481493
if isinstance(expr, str):
482-
raise TypeError(EXPR_TYPE_ERROR)
483-
if isinstance(expr, Iterable) and not isinstance(expr, Expr):
484-
expr_value = list(expr)
485-
if any(isinstance(inner, str) for inner in expr_value):
486-
raise TypeError(EXPR_TYPE_ERROR)
494+
yield expr
495+
elif isinstance(expr, Iterable) and not isinstance(expr, Expr):
496+
yield from _iter_exprs(expr)
487497
else:
488-
expr_value = expr
489-
try:
490-
expr_list.extend(expr_list_to_raw_expr_list(expr_value))
491-
except TypeError as err:
492-
raise TypeError(EXPR_TYPE_ERROR) from err
493-
for alias, expr in named_exprs.items():
494-
if not isinstance(expr, Expr):
495-
raise TypeError(EXPR_TYPE_ERROR)
496-
expr_list.append(expr.alias(alias).expr)
497-
return expr_list
498-
499-
expressions = _simplify_expression(*exprs, **named_exprs)
498+
yield expr
499+
500+
expressions = [ensure_expr(e) for e in _iter_exprs(exprs)]
501+
for alias, expr in named_exprs.items():
502+
ensure_expr(expr)
503+
expressions.append(expr.alias(alias).expr)
500504

501505
return DataFrame(self.df.with_columns(expressions))
502506

@@ -535,11 +539,7 @@ def aggregate(
535539
aggs_list = aggs if isinstance(aggs, list) else [aggs]
536540

537541
group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
538-
aggs_exprs = []
539-
for agg in aggs_list:
540-
if not isinstance(agg, Expr):
541-
raise TypeError(EXPR_TYPE_ERROR)
542-
aggs_exprs.append(agg.expr)
542+
aggs_exprs = [ensure_expr(agg) for agg in aggs_list]
543543
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
544544

545545
def sort(self, *exprs: SortKey) -> DataFrame:
@@ -554,7 +554,7 @@ def sort(self, *exprs: SortKey) -> DataFrame:
554554
Returns:
555555
DataFrame after sorting.
556556
"""
557-
exprs_raw = sort_list_to_raw_sort_list(list(exprs))
557+
exprs_raw = sort_list_to_raw_sort_list(exprs)
558558
return DataFrame(self.df.sort(*exprs_raw))
559559

560560
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -766,8 +766,15 @@ def join_on(
766766
) -> DataFrame:
767767
"""Join two :py:class:`DataFrame` using the specified expressions.
768768
769-
On expressions are used to support in-equality predicates. Equality
770-
predicates are correctly optimized
769+
Join predicates must be :class:`~datafusion.expr.Expr` objects, typically
770+
built with :func:`datafusion.col`; plain strings are not accepted. On
771+
expressions are used to support in-equality predicates. Equality predicates
772+
are correctly optimized.
773+
774+
Example::
775+
776+
from datafusion import col
777+
df.join_on(other_df, col("id") == col("other_id"))
771778
772779
Args:
773780
right: Other DataFrame to join with.
@@ -778,11 +785,7 @@ def join_on(
778785
Returns:
779786
DataFrame after join.
780787
"""
781-
exprs = []
782-
for expr in on_exprs:
783-
if not isinstance(expr, Expr):
784-
raise TypeError(EXPR_TYPE_ERROR)
785-
exprs.append(expr.expr)
788+
exprs = [ensure_expr(expr) for expr in on_exprs]
786789
return DataFrame(self.df.join_on(right.df, exprs, how))
787790

788791
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence, Union
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
2626

2727
import pyarrow as pa
2828

@@ -41,9 +41,8 @@
4141

4242

4343
# Standard error message for invalid expression types
44-
EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions"
45-
46-
SortKey = Union["Expr", "SortExpr", str]
44+
# Mention both alias forms of column and literal helpers
45+
EXPR_TYPE_ERROR = "Use col()/column() or lit()/literal() to construct expressions"
4746

4847
# The following are imported from the internal representation. We may choose to
4948
# give these all proper wrappers, or to simply leave as is. These were added
@@ -219,9 +218,54 @@
219218
"WindowExpr",
220219
"WindowFrame",
221220
"WindowFrameBound",
221+
"ensure_expr",
222222
]
223223

224224

225+
def ensure_expr(value: Expr | Any) -> expr_internal.Expr:
226+
"""Return the internal expression from ``Expr`` or raise ``TypeError``.
227+
228+
This helper rejects plain strings and other non-:class:`Expr` values so
229+
higher level APIs consistently require explicit :func:`~datafusion.col` or
230+
:func:`~datafusion.lit` expressions.
231+
232+
Args:
233+
value: Candidate expression or other object.
234+
235+
Returns:
236+
The internal expression representation.
237+
238+
Raises:
239+
TypeError: If ``value`` is not an instance of :class:`Expr`.
240+
"""
241+
if not isinstance(value, Expr):
242+
raise TypeError(EXPR_TYPE_ERROR)
243+
return value.expr
244+
245+
246+
def _to_raw_expr(value: Expr | str) -> expr_internal.Expr:
247+
"""Convert a Python expression or column name to its raw variant.
248+
249+
Args:
250+
value: Candidate expression or column name.
251+
252+
Returns:
253+
The internal :class:`~datafusion._internal.expr.Expr` representation.
254+
255+
Raises:
256+
TypeError: If ``value`` is neither an :class:`Expr` nor ``str``.
257+
"""
258+
if isinstance(value, str):
259+
return Expr.column(value).expr
260+
if isinstance(value, Expr):
261+
return value.expr
262+
error = (
263+
"Expected Expr or column name, found:"
264+
f" {type(value).__name__}. {EXPR_TYPE_ERROR}."
265+
)
266+
raise TypeError(error)
267+
268+
225269
def expr_list_to_raw_expr_list(
226270
expr_list: Optional[Sequence[Expr | str] | Expr | str],
227271
) -> Optional[list[expr_internal.Expr]]:
@@ -230,30 +274,18 @@ def expr_list_to_raw_expr_list(
230274
expr_list = [expr_list]
231275
if expr_list is None:
232276
return None
233-
raw_exprs: list[expr_internal.Expr] = []
234-
for e in expr_list:
235-
if isinstance(e, str):
236-
raw_exprs.append(Expr.column(e).expr)
237-
elif isinstance(e, Expr):
238-
raw_exprs.append(e.expr)
239-
else:
240-
error = (
241-
"Expected Expr or column name, found:"
242-
f" {type(e).__name__}. {EXPR_TYPE_ERROR}."
243-
)
244-
raise TypeError(error)
245-
return raw_exprs
277+
return [_to_raw_expr(e) for e in expr_list]
246278

247279

248280
def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
249-
"""Helper function to return a default Sort if an Expr is provided."""
281+
"""Return a :class:`SortExpr`, defaulting attributes when necessary."""
250282
if isinstance(e, SortExpr):
251283
return e.raw_sort
252284
return SortExpr(e, ascending=True, nulls_first=True).raw_sort
253285

254286

255287
def sort_list_to_raw_sort_list(
256-
sort_list: Optional[list[SortKey] | SortKey],
288+
sort_list: Optional[Sequence[SortKey] | SortKey],
257289
) -> Optional[list[expr_internal.SortExpr]]:
258290
"""Helper function to return an optional sort list to raw variant."""
259291
if isinstance(sort_list, (Expr, SortExpr, str)):
@@ -262,17 +294,11 @@ def sort_list_to_raw_sort_list(
262294
return None
263295
raw_sort_list = []
264296
for item in sort_list:
265-
if isinstance(item, str):
266-
expr_obj = Expr.column(item)
267-
elif isinstance(item, (Expr, SortExpr)):
268-
expr_obj = item
297+
if isinstance(item, SortExpr):
298+
raw_sort_list.append(sort_or_default(item))
269299
else:
270-
error = (
271-
"Expected Expr or column name, found:"
272-
f" {type(item).__name__}. {EXPR_TYPE_ERROR}."
273-
)
274-
raise TypeError(error)
275-
raw_sort_list.append(sort_or_default(expr_obj))
300+
raw_expr = _to_raw_expr(item) # may raise ``TypeError``
301+
raw_sort_list.append(sort_or_default(Expr(raw_expr)))
276302
return raw_sort_list
277303

278304

@@ -1335,3 +1361,6 @@ def nulls_first(self) -> bool:
13351361
def __repr__(self) -> str:
13361362
"""Generate a string representation of this expression."""
13371363
return self.raw_sort.__repr__()
1364+
1365+
1366+
SortKey = Expr | SortExpr | str

0 commit comments

Comments
 (0)