Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions daft/functions/agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,25 @@
from typing import Literal

from daft.daft import CountMode
from daft.expressions.expressions import Expression
from daft.expressions.expressions import Expression, col


def count(
    expr: Expression | None = None, mode: Literal["all", "valid", "null"] | CountMode = CountMode.Valid
) -> Expression:
    """Counts the number of values in the expression.

    Args:
        expr (Expression | None): The input expression to count values from. If not provided, mode must be "all"
            and count(*) semantics will be used.
        mode: A string ("all", "valid", or "null") that represents whether to count all values, non-null (valid) values, or null values. Defaults to "valid".

    Returns:
        Expression: The count aggregation built from ``expr`` (or from ``col("*")`` when ``expr`` is omitted).

    Raises:
        ValueError: If ``expr`` is omitted and ``mode`` is not "all". Because ``mode`` defaults to "valid",
            a bare ``count()`` call raises; pass ``mode="all"`` explicitly.
    """
    # Normalize string modes first so the check below compares enum values only.
    if isinstance(mode, str):
        mode = CountMode.from_count_mode_str(mode)
    if expr is None:
        # With no expression the count runs over the wildcard column, which is
        # only permitted when counting all rows.
        if mode != CountMode.All:
            raise ValueError("count() without an expression only supports mode='all'.")
        expr = col("*")
    return Expression._from_pyexpr(expr._expr.count(mode))


Expand Down
37 changes: 37 additions & 0 deletions tests/dataframe/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,43 @@ def test_agg_with_literal_child(make_df, repartition_nparts, with_morsel_size):
assert res == {"count_lit": [3], "sum_lit": [3]}


@pytest.mark.parametrize("repartition_nparts", [1, 2, 4, 8])
def test_agg_count_mode_all_without_expr(make_df, repartition_nparts, with_morsel_size):
    """Global ``count(mode="all")`` with no expression yields the number of rows."""
    rows = {"a": [1, 1, 2], "i": [0, 1, 2]}
    df = make_df(rows, repartition=repartition_nparts)

    count_all = daft.functions.count(mode="all").alias("count_all")
    aggregated = df.agg(count_all)

    # Three input rows, regardless of how the frame was repartitioned.
    assert aggregated.to_pydict() == {"count_all": [3]}


@pytest.mark.parametrize("repartition_nparts", [1, 2, 4, 8])
def test_groupby_agg_count_mode_all_without_expr(make_df, repartition_nparts, with_morsel_size):
    """Grouped ``count(mode="all")`` with no expression yields the row count per group."""
    rows = {"a": [1, 1, 2], "i": [0, 1, 2]}
    df = make_df(rows, repartition=repartition_nparts)

    by_a = df.groupby("a")
    count_all = daft.functions.count(mode="all").alias("count_all")
    # Sort on the key so the comparison does not depend on group output order.
    per_group = by_a.agg(count_all).sort("a")

    assert per_group.to_pydict() == {"a": [1, 2], "count_all": [2, 1]}


def test_agg_count_mode_all_without_expr_empty_df(with_morsel_size):
    """``count(mode="all")`` with no expression returns 0 on an empty dataframe."""
    untyped = daft.from_pydict({"a": []})
    # Cast the empty column to int64 before aggregating.
    empty_df = untyped.with_column("a", col("a").cast(DataType.int64()))

    count_all = daft.functions.count(mode="all").alias("count_all")
    aggregated = empty_df.agg(count_all)

    assert aggregated.to_pydict() == {"count_all": [0]}


def test_agg_count_without_expr_non_all_mode_raises():
    """Omitting the expression is rejected for the "valid" and "null" modes."""
    for rejected_mode in ("valid", "null"):
        with pytest.raises(ValueError, match="only supports mode='all'"):
            daft.functions.count(mode=rejected_mode)


@pytest.mark.parametrize("repartition_nparts", [1, 2, 4, 8])
def test_groupby_agg_with_literal_child(make_df, repartition_nparts, with_morsel_size):
"""Test grouped aggregation of literal values.
Expand Down
Loading