Skip to content

Commit 810fa20

Browse files
committed
refactor: Implement DaskLazyGroupBy
Much happier with this than the `pandas` one
1 parent 9637662 commit 810fa20

File tree

2 files changed

+43
-90
lines changed

2 files changed

+43
-90
lines changed

narwhals/_dask/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def join_asof(
388388
def group_by(self: Self, *by: str, drop_null_keys: bool) -> DaskLazyGroupBy:
389389
from narwhals._dask.group_by import DaskLazyGroupBy
390390

391-
return DaskLazyGroupBy(self, list(by), drop_null_keys=drop_null_keys)
391+
return DaskLazyGroupBy(self, by, drop_null_keys=drop_null_keys)
392392

393393
def tail(self: Self, n: int) -> Self: # pragma: no cover
394394
native_frame = self._native_frame

narwhals/_dask/group_by.py

Lines changed: 42 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from typing import TYPE_CHECKING
66
from typing import Any
77
from typing import Callable
8+
from typing import ClassVar
89
from typing import Mapping
910
from typing import Sequence
1011

1112
import dask.dataframe as dd
1213

14+
from narwhals._compliant import CompliantGroupBy
1315
from narwhals._expression_parsing import evaluate_output_names_and_aliases
14-
from narwhals._expression_parsing import is_elementary_expression
1516

1617
try:
1718
import dask.dataframe.dask_expr as dx
@@ -54,90 +55,54 @@ def std(ddof: int) -> _AggFn:
5455
return partial(_DaskGroupBy.std, ddof=ddof)
5556

5657

57-
POLARS_TO_DASK_AGGREGATIONS: Mapping[str, Aggregation] = {
58-
"sum": "sum",
59-
"mean": "mean",
60-
"median": "median",
61-
"max": "max",
62-
"min": "min",
63-
"std": std,
64-
"var": var,
65-
"len": "size",
66-
"n_unique": n_unique,
67-
"count": "count",
68-
}
58+
class DaskLazyGroupBy(CompliantGroupBy["DaskLazyFrame", "DaskExpr"]):
59+
_NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Aggregation]] = {
60+
"sum": "sum",
61+
"mean": "mean",
62+
"median": "median",
63+
"max": "max",
64+
"min": "min",
65+
"std": std,
66+
"var": var,
67+
"len": "size",
68+
"n_unique": n_unique,
69+
"count": "count",
70+
}
6971

70-
71-
class DaskLazyGroupBy:
7272
def __init__(
73-
self: Self, df: DaskLazyFrame, keys: list[str], *, drop_null_keys: bool
73+
self: Self, df: DaskLazyFrame, keys: Sequence[str], /, *, drop_null_keys: bool
7474
) -> None:
75-
self._df: DaskLazyFrame = df
76-
self._keys = keys
77-
self._grouped = self._df._native_frame.groupby(
78-
list(self._keys),
79-
dropna=drop_null_keys,
80-
observed=True,
81-
)
82-
83-
def agg(
84-
self: Self,
85-
*exprs: DaskExpr,
86-
) -> DaskLazyFrame:
87-
return agg_dask(
88-
self._df,
89-
self._grouped,
90-
exprs,
91-
self._keys,
92-
self._from_native_frame,
75+
self._compliant_frame = df
76+
self._keys: list[str] = list(keys)
77+
self._grouped = self.compliant.native.groupby(
78+
list(self._keys), dropna=drop_null_keys, observed=True
9379
)
9480

95-
def _from_native_frame(self: Self, df: dd.DataFrame) -> DaskLazyFrame:
81+
def agg(self: Self, *exprs: DaskExpr) -> DaskLazyFrame:
9682
from narwhals._dask.dataframe import DaskLazyFrame
9783

98-
return DaskLazyFrame(
99-
df,
100-
backend_version=self._df._backend_version,
101-
version=self._df._version,
102-
)
103-
104-
105-
def agg_dask(
106-
df: DaskLazyFrame,
107-
grouped: Any,
108-
exprs: Sequence[DaskExpr],
109-
keys: list[str],
110-
from_dataframe: Callable[[Any], DaskLazyFrame],
111-
) -> DaskLazyFrame:
112-
"""This should be the fastpath, but cuDF is too far behind to use it.
113-
114-
- https://github.com/rapidsai/cudf/issues/15118
115-
- https://github.com/rapidsai/cudf/issues/15084
116-
"""
117-
if not exprs:
118-
# No aggregation provided
119-
return df.simple_select(*keys).unique(subset=keys, keep="any")
120-
121-
all_simple_aggs = True
122-
for expr in exprs:
123-
if not (
124-
is_elementary_expression(expr)
125-
and re.sub(r"(\w+->)", "", expr._function_name) in POLARS_TO_DASK_AGGREGATIONS
126-
):
127-
all_simple_aggs = False
128-
break
129-
130-
if all_simple_aggs:
84+
if not exprs:
85+
# No aggregation provided
86+
return self.compliant.simple_select(*self._keys).unique(
87+
self._keys, keep="any"
88+
)
89+
self._ensure_all_simple(exprs)
90+
# This should be the fastpath, but cuDF is too far behind to use it.
91+
# - https://github.com/rapidsai/cudf/issues/15118
92+
# - https://github.com/rapidsai/cudf/issues/15084
93+
POLARS_TO_DASK_AGGREGATIONS = self._NARWHALS_TO_NATIVE_AGGREGATIONS # noqa: N806
13194
simple_aggregations: dict[str, tuple[str, Aggregation]] = {}
13295
for expr in exprs:
133-
output_names, aliases = evaluate_output_names_and_aliases(expr, df, keys)
96+
output_names, aliases = evaluate_output_names_and_aliases(
97+
expr, self.compliant, self._keys
98+
)
13499
if expr._depth == 0:
135100
# e.g. agg(nw.len()) # noqa: ERA001
136101
function_name = POLARS_TO_DASK_AGGREGATIONS.get(
137102
expr._function_name, expr._function_name
138103
)
139104
simple_aggregations.update(
140-
dict.fromkeys(aliases, (keys[0], function_name))
105+
dict.fromkeys(aliases, (self._keys[0], function_name))
141106
)
142107
continue
143108

@@ -150,24 +115,12 @@ def agg_dask(
150115
if callable(agg_function)
151116
else agg_function
152117
)
153-
154118
simple_aggregations.update(
155-
{
156-
alias: (output_name, agg_function)
157-
for alias, output_name in zip(aliases, output_names)
158-
}
119+
(alias, (output_name, agg_function))
120+
for alias, output_name in zip(aliases, output_names)
159121
)
160-
result_simple = grouped.agg(**simple_aggregations)
161-
return from_dataframe(result_simple.reset_index())
162-
163-
msg = (
164-
"Non-trivial complex aggregation found.\n\n"
165-
"Hint: you were probably trying to apply a non-elementary aggregation with a "
166-
"dask dataframe.\n"
167-
"Please rewrite your query such that group-by aggregations "
168-
"are elementary. For example, instead of:\n\n"
169-
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
170-
"use:\n\n"
171-
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
172-
)
173-
raise ValueError(msg)
122+
return DaskLazyFrame(
123+
self._grouped.agg(**simple_aggregations).reset_index(),
124+
backend_version=self.compliant._backend_version,
125+
version=self.compliant._version,
126+
)

0 commit comments

Comments
 (0)