diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index 1387db0bd..6d82f7078 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -95,8 +95,9 @@ DataFusion's DataFrame API offers a wide range of operations: # Select with expressions df = df.select(column("a") + column("b"), column("a") - column("b")) - # Filter rows + # Filter rows (expressions or SQL strings) df = df.filter(column("age") > literal(25)) + df = df.filter("age > 25") # Add computed columns df = df.with_column("full_name", column("first_name") + literal(" ") + column("last_name")) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index d15111d57..0b77d42ef 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -466,31 +466,37 @@ def drop(self, *columns: str) -> DataFrame: return DataFrame(self.df.drop(*normalized_columns)) - def filter(self, *predicates: Expr) -> DataFrame: + def filter(self, *predicates: Expr | str) -> DataFrame: """Return a DataFrame for which ``predicate`` evaluates to ``True``. Rows for which ``predicate`` evaluates to ``False`` or ``None`` are filtered out. If more than one predicate is provided, these predicates will be - combined as a logical AND. Each ``predicate`` must be an + combined as a logical AND. Each ``predicate`` can be an :class:`~datafusion.expr.Expr` created using helper functions such as - :func:`datafusion.col` or :func:`datafusion.lit`. - If more complex logic is required, see the logical operations in - :py:mod:`~datafusion.functions`. + :func:`datafusion.col` or :func:`datafusion.lit`, or a SQL expression string + that will be parsed against the DataFrame schema. If more complex logic is + required, see the logical operations in :py:mod:`~datafusion.functions`. Example:: from datafusion import col, lit df.filter(col("a") > lit(1)) + df.filter("a > 1") Args: - predicates: Predicate expression(s) to filter the DataFrame. + predicates: Predicate expression(s) or SQL strings to filter the DataFrame. Returns: DataFrame after filtering. """ df = self.df - for p in predicates: - df = df.filter(ensure_expr(p)) + for predicate in predicates: + expr = ( + self.parse_sql_expr(predicate) + if isinstance(predicate, str) + else predicate + ) + df = df.filter(ensure_expr(expr)) return DataFrame(df) def parse_sql_expr(self, expr: str) -> Expr: diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 9317711f4..aceebadb4 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -306,6 +306,29 @@ def test_filter(df): assert result.column(2) == pa.array([5]) +def test_filter_string_predicates(df): + df_str = df.filter("a > 2") + result = df_str.collect()[0] + + assert result.column(0) == pa.array([3]) + assert result.column(1) == pa.array([6]) + assert result.column(2) == pa.array([8]) + + df_mixed = df.filter("a > 1", column("b") != literal(6)) + result_mixed = df_mixed.collect()[0] + + assert result_mixed.column(0) == pa.array([2]) + assert result_mixed.column(1) == pa.array([5]) + assert result_mixed.column(2) == pa.array([5]) + + df_strings = df.filter("a > 1", "b < 6") + result_strings = df_strings.collect()[0] + + assert result_strings.column(0) == pa.array([2]) + assert result_strings.column(1) == pa.array([5]) + assert result_strings.column(2) == pa.array([5]) + + def test_parse_sql_expr(df): plan1 = df.filter(df.parse_sql_expr("a > 2")).logical_plan() plan2 = df.filter(column("a") > literal(2)).logical_plan() @@ -388,9 +411,16 @@ def test_aggregate_tuple_aggs(df): assert result_tuple == result_list -def test_filter_string_unsupported(df): - with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): - df.filter("a > 1") +def test_filter_string_equivalent(df): + df1 = df.filter("a > 1").to_pydict() + df2 = df.filter(column("a") > literal(1)).to_pydict() + assert df1 == df2 + + +def test_filter_string_invalid(df): + with pytest.raises(Exception) as excinfo: + df.filter("this is not valid sql").collect() + assert "Expected Expr" not in str(excinfo.value) def test_drop(df):