Patch notes (free text ahead of the first "diff --git" header; git apply and
patch(1) skip leading text that is not part of a file header or hunk).

* daft/functions/agg.py: count() now accepts expr=None. When no expression is
  given, mode must resolve to CountMode.All, otherwise ValueError is raised;
  the aggregation is then built on col("*").
* tests/dataframe/test_aggregations.py: adds tests for the expression-less
  count in a global agg, a grouped agg, an empty DataFrame, and the
  ValueError path for mode="valid" / mode="null".

NOTE(review): the default for mode is still CountMode.Valid, so a bare
count() call (no expr, no mode) takes the ValueError branch; callers must
pass mode="all" explicitly. No test in this patch exercises the bare call.

NOTE(review): line breaks and indentation below were restored from a copy in
which all whitespace had been collapsed. Line counts were checked against the
@@ hunk headers (18 -> 25 and 6 -> 43); exact indentation widths are assumed
to follow the usual 4-space Python style and should be confirmed against the
target files before applying.

diff --git a/daft/functions/agg.py b/daft/functions/agg.py
index 73395b3814..9f20c5444d 100644
--- a/daft/functions/agg.py
+++ b/daft/functions/agg.py
@@ -5,18 +5,25 @@
 from typing import Literal
 
 from daft.daft import CountMode
-from daft.expressions.expressions import Expression
+from daft.expressions.expressions import Expression, col
 
 
-def count(expr: Expression, mode: Literal["all", "valid", "null"] | CountMode = CountMode.Valid) -> Expression:
+def count(
+    expr: Expression | None = None, mode: Literal["all", "valid", "null"] | CountMode = CountMode.Valid
+) -> Expression:
     """Counts the number of values in the expression.
 
     Args:
-        expr (Expression): The input expression to count values from.
+        expr (Expression | None): The input expression to count values from. If not provided, mode must be "all"
+            and count(*) semantics will be used.
         mode: A string ("all", "valid", or "null") that represents whether to count all values, non-null (valid) values, or null values. Defaults to "valid".
     """
     if isinstance(mode, str):
         mode = CountMode.from_count_mode_str(mode)
+    if expr is None:
+        if mode != CountMode.All:
+            raise ValueError("count() without an expression only supports mode='all'.")
+        expr = col("*")
     return Expression._from_pyexpr(expr._expr.count(mode))
 
 
diff --git a/tests/dataframe/test_aggregations.py b/tests/dataframe/test_aggregations.py
index 5299257886..d2359b2b95 100644
--- a/tests/dataframe/test_aggregations.py
+++ b/tests/dataframe/test_aggregations.py
@@ -740,6 +740,43 @@ def test_agg_with_literal_child(make_df, repartition_nparts, with_morsel_size):
     assert res == {"count_lit": [3], "sum_lit": [3]}
 
 
+@pytest.mark.parametrize("repartition_nparts", [1, 2, 4, 8])
+def test_agg_count_mode_all_without_expr(make_df, repartition_nparts, with_morsel_size):
+    daft_df = make_df(
+        {"a": [1, 1, 2], "i": [0, 1, 2]},
+        repartition=repartition_nparts,
+    )
+
+    result = daft_df.agg(
+        daft.functions.count(mode="all").alias("count_all"),
+    )
+    assert result.to_pydict() == {"count_all": [3]}
+
+
+@pytest.mark.parametrize("repartition_nparts", [1, 2, 4, 8])
+def test_groupby_agg_count_mode_all_without_expr(make_df, repartition_nparts, with_morsel_size):
+    daft_df = make_df(
+        {"a": [1, 1, 2], "i": [0, 1, 2]},
+        repartition=repartition_nparts,
+    )
+
+    result = daft_df.groupby("a").agg(daft.functions.count(mode="all").alias("count_all")).sort("a")
+    assert result.to_pydict() == {"a": [1, 2], "count_all": [2, 1]}
+
+
+def test_agg_count_mode_all_without_expr_empty_df(with_morsel_size):
+    daft_df = daft.from_pydict({"a": []}).with_column("a", col("a").cast(DataType.int64()))
+    result = daft_df.agg(daft.functions.count(mode="all").alias("count_all"))
+    assert result.to_pydict() == {"count_all": [0]}
+
+
+def test_agg_count_without_expr_non_all_mode_raises():
+    with pytest.raises(ValueError, match="only supports mode='all'"):
+        daft.functions.count(mode="valid")
+    with pytest.raises(ValueError, match="only supports mode='all'"):
+        daft.functions.count(mode="null")
+
+
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 4, 8])
 def test_groupby_agg_with_literal_child(make_df, repartition_nparts, with_morsel_size):
     """Test grouped aggregation of literal values.