Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/user-guide/dataframe/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ DataFusion's DataFrame API offers a wide range of operations:
# Select with expressions
df = df.select(column("a") + column("b"), column("a") - column("b"))

# Filter rows
# Filter rows (expressions or SQL strings)
df = df.filter(column("age") > literal(25))
df = df.filter("age > 25")

# Add computed columns
df = df.with_column("full_name", column("first_name") + literal(" ") + column("last_name"))
Expand Down
22 changes: 14 additions & 8 deletions python/datafusion/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,31 +466,37 @@ def drop(self, *columns: str) -> DataFrame:

return DataFrame(self.df.drop(*normalized_columns))

def filter(self, *predicates: Expr) -> DataFrame:
def filter(self, *predicates: Expr | str) -> DataFrame:
"""Return a DataFrame for which ``predicate`` evaluates to ``True``.

Rows for which ``predicate`` evaluates to ``False`` or ``None`` are filtered
out. If more than one predicate is provided, these predicates will be
combined as a logical AND. Each ``predicate`` must be an
combined as a logical AND. Each ``predicate`` can be an
:class:`~datafusion.expr.Expr` created using helper functions such as
:func:`datafusion.col` or :func:`datafusion.lit`.
If more complex logic is required, see the logical operations in
:py:mod:`~datafusion.functions`.
:func:`datafusion.col` or :func:`datafusion.lit`, or a SQL expression string
that will be parsed against the DataFrame schema. If more complex logic is
required, see the logical operations in :py:mod:`~datafusion.functions`.

Example::

from datafusion import col, lit
df.filter(col("a") > lit(1))
df.filter("a > 1")

Args:
predicates: Predicate expression(s) to filter the DataFrame.
predicates: Predicate expression(s) or SQL strings to filter the DataFrame.

Returns:
DataFrame after filtering.
"""
df = self.df
for p in predicates:
df = df.filter(ensure_expr(p))
for predicate in predicates:
expr = (
self.parse_sql_expr(predicate)
if isinstance(predicate, str)
else predicate
)
df = df.filter(ensure_expr(expr))
return DataFrame(df)

def parse_sql_expr(self, expr: str) -> Expr:
Expand Down
36 changes: 33 additions & 3 deletions python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,29 @@ def test_filter(df):
assert result.column(2) == pa.array([5])


def test_filter_string_predicates(df):
    """Verify that ``DataFrame.filter`` accepts SQL strings as predicates.

    Three call shapes are exercised in turn: a single SQL string, a SQL
    string combined with an ``Expr``, and two SQL strings together.
    """

    def check_columns(batch, expected_columns):
        # Compare the leading columns of the batch, in order, against the
        # expected values.
        for position, values in enumerate(expected_columns):
            assert batch.column(position) == pa.array(values)

    # A single SQL string predicate.
    single_batch = df.filter("a > 2").collect()[0]
    check_columns(single_batch, ([3], [6], [8]))

    # A SQL string alongside an expression built with column()/literal().
    mixed_batch = df.filter("a > 1", column("b") != literal(6)).collect()[0]
    check_columns(mixed_batch, ([2], [5], [5]))

    # Several SQL strings passed as separate predicates.
    strings_batch = df.filter("a > 1", "b < 6").collect()[0]
    check_columns(strings_batch, ([2], [5], [5]))


def test_parse_sql_expr(df):
plan1 = df.filter(df.parse_sql_expr("a > 2")).logical_plan()
plan2 = df.filter(column("a") > literal(2)).logical_plan()
Expand Down Expand Up @@ -388,9 +411,16 @@ def test_aggregate_tuple_aggs(df):
assert result_tuple == result_list


def test_filter_string_unsupported(df):
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
df.filter("a > 1")
def test_filter_string_equivalent(df):
    """A SQL string predicate selects the same rows as the matching ``Expr``."""
    rows_from_string = df.filter("a > 1").to_pydict()
    expr_predicate = column("a") > literal(1)
    assert rows_from_string == df.filter(expr_predicate).to_pydict()


def test_filter_string_invalid(df):
    """Filtering with an unparsable SQL string raises an error.

    The error text must not contain ``Expected Expr``.
    """
    bad_predicate = "this is not valid sql"
    with pytest.raises(Exception) as raised:
        df.filter(bad_predicate).collect()
    message = str(raised.value)
    assert "Expected Expr" not in message


def test_drop(df):
Expand Down