4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
4343from datafusion .expr import (
44- EXPR_TYPE_ERROR ,
4544 Expr ,
4645 SortKey ,
46+ ensure_expr ,
4747 expr_list_to_raw_expr_list ,
4848 sort_list_to_raw_sort_list ,
4949)
5858 import polars as pl
5959 import pyarrow as pa
6060
61- from datafusion ._internal import expr as expr_internal
62-
6361from enum import Enum
6462
6563
@@ -418,9 +416,17 @@ def filter(self, *predicates: Expr) -> DataFrame:
418416 """Return a DataFrame for which ``predicate`` evaluates to ``True``.
419417
420418 Rows for which ``predicate`` evaluates to ``False`` or ``None`` are filtered
421- out. If more than one predicate is provided, these predicates will be
422- combined as a logical AND. If more complex logic is required, see the
423- logical operations in :py:mod:`~datafusion.functions`.
419+ out. If more than one predicate is provided, these predicates will be
420+ combined as a logical AND. Each ``predicate`` must be an
421+ :class:`~datafusion.expr.Expr` created using helper functions such as
422+ :func:`datafusion.col` or :func:`datafusion.lit`; plain strings are not
423+ accepted. If more complex logic is required, see the logical operations in
424+ :py:mod:`~datafusion.functions`.
425+
426+ Example::
427+
428+ from datafusion import col, lit
429+ df.filter(col("a") > lit(1))
424430
425431 Args:
426432 predicates: Predicate expression(s) to filter the DataFrame.
@@ -430,40 +436,49 @@ def filter(self, *predicates: Expr) -> DataFrame:
430436 """
431437 df = self .df
432438 for p in predicates :
433- if not isinstance (p , Expr ):
434- raise TypeError (EXPR_TYPE_ERROR )
435- df = df .filter (p .expr )
439+ df = df .filter (ensure_expr (p ))
436440 return DataFrame (df )
437441
438442 def with_column (self , name : str , expr : Expr ) -> DataFrame :
439443 """Add an additional column to the DataFrame.
440444
445+ The ``expr`` must be an :class:`~datafusion.expr.Expr` constructed with
446+ :func:`datafusion.col` or :func:`datafusion.lit`; plain strings are not
447+ accepted.
448+
449+ Example::
450+
451+ from datafusion import col, lit
452+ df.with_column("b", col("a") + lit(1))
453+
441454 Args:
442455 name: Name of the column to add.
443456 expr: Expression to compute the column.
444457
445458 Returns:
446459 DataFrame with the new column.
447460 """
448- if not isinstance (expr , Expr ):
449- raise TypeError (EXPR_TYPE_ERROR )
450- return DataFrame (self .df .with_column (name , expr .expr ))
461+ return DataFrame (self .df .with_column (name , ensure_expr (expr )))
451462
452463 def with_columns (
453464 self , * exprs : Expr | Iterable [Expr ], ** named_exprs : Expr
454465 ) -> DataFrame :
455466 """Add columns to the DataFrame.
456467
457- By passing expressions, iteratables of expressions, or named expressions. To
458- pass named expressions use the form name=Expr.
468+ Accepts expressions, iterables of expressions, or named expressions.
469+ All expressions must be :class:`~datafusion.expr.Expr` objects created via
470+ :func:`datafusion.col` or :func:`datafusion.lit`; plain strings are not
471+ accepted. To pass named expressions use the form ``name=Expr``.
459472
460- Example usage: The following will add 4 columns labeled a, b, c, and d::
473+ Example usage: The following will add 4 columns labeled ``a``, ``b``, ``c``,
474+ and ``d``::
461475
476+ from datafusion import col, lit
462477 df = df.with_columns(
463- lit(0 ).alias('a' ),
464- [lit(1).alias('b' ), lit(2 ).alias('c' )],
478+ col("x" ).alias("a" ),
479+ [lit(1).alias("b" ), col("y" ).alias("c" )],
465480 d=lit(3)
466- )
481+ )
467482
468483 Args:
469484 exprs: Either a single expression or an iterable of expressions to add.
@@ -473,30 +488,19 @@ def with_columns(
473488 DataFrame with the new columns added.
474489 """
475490
476- def _simplify_expression (
477- * exprs : Expr | Iterable [Expr ], ** named_exprs : Expr
478- ) -> list [expr_internal .Expr ]:
479- expr_list : list [expr_internal .Expr ] = []
480- for expr in exprs :
491+ def _iter_exprs (items : Iterable [Expr | Iterable [Expr ]]) -> Iterable [Expr | str ]:
492+ for expr in items :
481493 if isinstance (expr , str ):
482- raise TypeError (EXPR_TYPE_ERROR )
483- if isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
484- expr_value = list (expr )
485- if any (isinstance (inner , str ) for inner in expr_value ):
486- raise TypeError (EXPR_TYPE_ERROR )
494+ yield expr
495+ elif isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
496+ yield from _iter_exprs (expr )
487497 else :
488- expr_value = expr
489- try :
490- expr_list .extend (expr_list_to_raw_expr_list (expr_value ))
491- except TypeError as err :
492- raise TypeError (EXPR_TYPE_ERROR ) from err
493- for alias , expr in named_exprs .items ():
494- if not isinstance (expr , Expr ):
495- raise TypeError (EXPR_TYPE_ERROR )
496- expr_list .append (expr .alias (alias ).expr )
497- return expr_list
498-
499- expressions = _simplify_expression (* exprs , ** named_exprs )
498+ yield expr
499+
500+ expressions = [ensure_expr (e ) for e in _iter_exprs (exprs )]
501+ for alias , expr in named_exprs .items ():
502+ ensure_expr (expr )
503+ expressions .append (expr .alias (alias ).expr )
500504
501505 return DataFrame (self .df .with_columns (expressions ))
502506
@@ -535,11 +539,7 @@ def aggregate(
535539 aggs_list = aggs if isinstance (aggs , list ) else [aggs ]
536540
537541 group_by_exprs = expr_list_to_raw_expr_list (group_by_list )
538- aggs_exprs = []
539- for agg in aggs_list :
540- if not isinstance (agg , Expr ):
541- raise TypeError (EXPR_TYPE_ERROR )
542- aggs_exprs .append (agg .expr )
542+ aggs_exprs = [ensure_expr (agg ) for agg in aggs_list ]
543543 return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
544544
545545 def sort (self , * exprs : SortKey ) -> DataFrame :
@@ -554,7 +554,7 @@ def sort(self, *exprs: SortKey) -> DataFrame:
554554 Returns:
555555 DataFrame after sorting.
556556 """
557- exprs_raw = sort_list_to_raw_sort_list (list ( exprs ) )
557+ exprs_raw = sort_list_to_raw_sort_list (exprs )
558558 return DataFrame (self .df .sort (* exprs_raw ))
559559
560560 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
@@ -766,8 +766,15 @@ def join_on(
766766 ) -> DataFrame :
767767 """Join two :py:class:`DataFrame` using the specified expressions.
768768
769- On expressions are used to support in-equality predicates. Equality
770- predicates are correctly optimized
769+ Join predicates must be :class:`~datafusion.expr.Expr` objects, typically
770+ built with :func:`datafusion.col`; plain strings are not accepted. On
771+ expressions are used to support inequality predicates. Equality predicates
772+ are correctly optimized.
773+
774+ Example::
775+
776+ from datafusion import col
777+ df.join_on(other_df, col("id") == col("other_id"))
771778
772779 Args:
773780 right: Other DataFrame to join with.
@@ -778,11 +785,7 @@ def join_on(
778785 Returns:
779786 DataFrame after join.
780787 """
781- exprs = []
782- for expr in on_exprs :
783- if not isinstance (expr , Expr ):
784- raise TypeError (EXPR_TYPE_ERROR )
785- exprs .append (expr .expr )
788+ exprs = [ensure_expr (expr ) for expr in on_exprs ]
786789 return DataFrame (self .df .join_on (right .df , exprs , how ))
787790
788791 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments