55from typing import TYPE_CHECKING
66from typing import Any
77from typing import Callable
8+ from typing import ClassVar
89from typing import Mapping
910from typing import Sequence
1011
1112import dask .dataframe as dd
1213
14+ from narwhals ._compliant import CompliantGroupBy
1315from narwhals ._expression_parsing import evaluate_output_names_and_aliases
14- from narwhals ._expression_parsing import is_elementary_expression
1516
1617try :
1718 import dask .dataframe .dask_expr as dx
@@ -54,90 +55,54 @@ def std(ddof: int) -> _AggFn:
5455 return partial (_DaskGroupBy .std , ddof = ddof )
5556
5657
57- POLARS_TO_DASK_AGGREGATIONS : Mapping [str , Aggregation ] = {
58- "sum" : "sum" ,
59- "mean" : "mean" ,
60- "median" : "median" ,
61- "max" : "max" ,
62- "min" : "min" ,
63- "std" : std ,
64- "var" : var ,
65- "len" : "size" ,
66- "n_unique" : n_unique ,
67- "count" : "count" ,
68- }
58+ class DaskLazyGroupBy (CompliantGroupBy ["DaskLazyFrame" , "DaskExpr" ]):
59+ _NARWHALS_TO_NATIVE_AGGREGATIONS : ClassVar [Mapping [str , Aggregation ]] = {
60+ "sum" : "sum" ,
61+ "mean" : "mean" ,
62+ "median" : "median" ,
63+ "max" : "max" ,
64+ "min" : "min" ,
65+ "std" : std ,
66+ "var" : var ,
67+ "len" : "size" ,
68+ "n_unique" : n_unique ,
69+ "count" : "count" ,
70+ }
6971
70-
71- class DaskLazyGroupBy :
7272 def __init__ (
73- self : Self , df : DaskLazyFrame , keys : list [str ], * , drop_null_keys : bool
73+ self : Self , df : DaskLazyFrame , keys : Sequence [str ], / , * , drop_null_keys : bool
7474 ) -> None :
75- self ._df : DaskLazyFrame = df
76- self ._keys = keys
77- self ._grouped = self ._df ._native_frame .groupby (
78- list (self ._keys ),
79- dropna = drop_null_keys ,
80- observed = True ,
81- )
82-
83- def agg (
84- self : Self ,
85- * exprs : DaskExpr ,
86- ) -> DaskLazyFrame :
87- return agg_dask (
88- self ._df ,
89- self ._grouped ,
90- exprs ,
91- self ._keys ,
92- self ._from_native_frame ,
75+ self ._compliant_frame = df
76+ self ._keys : list [str ] = list (keys )
77+ self ._grouped = self .compliant .native .groupby (
78+ list (self ._keys ), dropna = drop_null_keys , observed = True
9379 )
9480
95- def _from_native_frame (self : Self , df : dd . DataFrame ) -> DaskLazyFrame :
81+ def agg (self : Self , * exprs : DaskExpr ) -> DaskLazyFrame :
9682 from narwhals ._dask .dataframe import DaskLazyFrame
9783
98- return DaskLazyFrame (
99- df ,
100- backend_version = self ._df ._backend_version ,
101- version = self ._df ._version ,
102- )
103-
104-
105- def agg_dask (
106- df : DaskLazyFrame ,
107- grouped : Any ,
108- exprs : Sequence [DaskExpr ],
109- keys : list [str ],
110- from_dataframe : Callable [[Any ], DaskLazyFrame ],
111- ) -> DaskLazyFrame :
112- """This should be the fastpath, but cuDF is too far behind to use it.
113-
114- - https://github.com/rapidsai/cudf/issues/15118
115- - https://github.com/rapidsai/cudf/issues/15084
116- """
117- if not exprs :
118- # No aggregation provided
119- return df .simple_select (* keys ).unique (subset = keys , keep = "any" )
120-
121- all_simple_aggs = True
122- for expr in exprs :
123- if not (
124- is_elementary_expression (expr )
125- and re .sub (r"(\w+->)" , "" , expr ._function_name ) in POLARS_TO_DASK_AGGREGATIONS
126- ):
127- all_simple_aggs = False
128- break
129-
130- if all_simple_aggs :
84+ if not exprs :
85+ # No aggregation provided
86+ return self .compliant .simple_select (* self ._keys ).unique (
87+ self ._keys , keep = "any"
88+ )
89+ self ._ensure_all_simple (exprs )
90+ # This should be the fastpath, but cuDF is too far behind to use it.
91+ # - https://github.com/rapidsai/cudf/issues/15118
92+ # - https://github.com/rapidsai/cudf/issues/15084
93+ POLARS_TO_DASK_AGGREGATIONS = self ._NARWHALS_TO_NATIVE_AGGREGATIONS # noqa: N806
13194 simple_aggregations : dict [str , tuple [str , Aggregation ]] = {}
13295 for expr in exprs :
133- output_names , aliases = evaluate_output_names_and_aliases (expr , df , keys )
96+ output_names , aliases = evaluate_output_names_and_aliases (
97+ expr , self .compliant , self ._keys
98+ )
13499 if expr ._depth == 0 :
135100 # e.g. agg(nw.len()) # noqa: ERA001
136101 function_name = POLARS_TO_DASK_AGGREGATIONS .get (
137102 expr ._function_name , expr ._function_name
138103 )
139104 simple_aggregations .update (
140- dict .fromkeys (aliases , (keys [0 ], function_name ))
105+ dict .fromkeys (aliases , (self . _keys [0 ], function_name ))
141106 )
142107 continue
143108
@@ -150,24 +115,12 @@ def agg_dask(
150115 if callable (agg_function )
151116 else agg_function
152117 )
153-
154118 simple_aggregations .update (
155- {
156- alias : (output_name , agg_function )
157- for alias , output_name in zip (aliases , output_names )
158- }
119+ (alias , (output_name , agg_function ))
120+ for alias , output_name in zip (aliases , output_names )
159121 )
160- result_simple = grouped .agg (** simple_aggregations )
161- return from_dataframe (result_simple .reset_index ())
162-
163- msg = (
164- "Non-trivial complex aggregation found.\n \n "
165- "Hint: you were probably trying to apply a non-elementary aggregation with a "
166- "dask dataframe.\n "
167- "Please rewrite your query such that group-by aggregations "
168- "are elementary. For example, instead of:\n \n "
169- " df.group_by('a').agg(nw.col('b').round(2).mean())\n \n "
170- "use:\n \n "
171- " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n \n "
172- )
173- raise ValueError (msg )
122+ return DaskLazyFrame (
123+ self ._grouped .agg (** simple_aggregations ).reset_index (),
124+ backend_version = self .compliant ._backend_version ,
125+ version = self .compliant ._version ,
126+ )
0 commit comments