55import warnings
66from typing import TYPE_CHECKING
77from typing import Any
8+ from typing import ClassVar
89from typing import Iterator
10+ from typing import Mapping
11+ from typing import Sequence
912
13+ from narwhals ._compliant import EagerGroupBy
1014from narwhals ._expression_parsing import evaluate_output_names_and_aliases
1115from narwhals ._expression_parsing import is_elementary_expression
1216from narwhals ._pandas_like .utils import horizontal_concat
2226 from narwhals ._pandas_like .dataframe import PandasLikeDataFrame
2327 from narwhals ._pandas_like .expr import PandasLikeExpr
2428
25- AGGREGATIONS_TO_PANDAS_EQUIVALENT = {
26- "sum" : "sum" ,
27- "mean" : "mean" ,
28- "median" : "median" ,
29- "max" : "max" ,
30- "min" : "min" ,
31- "std" : "std" ,
32- "var" : "var" ,
33- "len" : "size" ,
34- "n_unique" : "nunique" ,
35- "count" : "count" ,
36- }
3729
30+ class PandasLikeGroupBy (EagerGroupBy ["PandasLikeDataFrame" , "PandasLikeExpr" ]):
31+ _NARWHALS_TO_NATIVE_AGGREGATIONS : ClassVar [Mapping [str , Any ]] = {
32+ "sum" : "sum" ,
33+ "mean" : "mean" ,
34+ "median" : "median" ,
35+ "max" : "max" ,
36+ "min" : "min" ,
37+ "std" : "std" ,
38+ "var" : "var" ,
39+ "len" : "size" ,
40+ "n_unique" : "nunique" ,
41+ "count" : "count" ,
42+ }
3843
39- class PandasLikeGroupBy :
4044 def __init__ (
41- self : Self , df : PandasLikeDataFrame , keys : list [str ], * , drop_null_keys : bool
45+ self : Self ,
46+ df : PandasLikeDataFrame ,
47+ keys : Sequence [str ],
48+ / ,
49+ * ,
50+ drop_null_keys : bool ,
4251 ) -> None :
43- self ._df = df
44- self ._keys = keys
52+ self ._compliant_frame = df
53+ self ._keys : list [ str ] = list ( keys )
4554 # Drop index to avoid potential collisions:
4655 # https://github.com/narwhals-dev/narwhals/issues/1907.
47- if set (df ._native_frame .index .names ).intersection (df .columns ):
48- native_frame = df ._native_frame .reset_index (drop = True )
56+ if set (df .native .index .names ).intersection (df .columns ):
57+ native_frame = df .native .reset_index (drop = True )
4958 else :
50- native_frame = df ._native_frame
59+ native_frame = df .native
5160 if (
52- self ._df ._implementation is Implementation .PANDAS
53- and self ._df ._backend_version < (1 , 1 )
61+ self .compliant ._implementation is Implementation .PANDAS
62+ and self .compliant ._backend_version < (1 , 1 )
5463 ): # pragma: no cover
5564 if (
5665 not drop_null_keys
57- and self ._df .simple_select (* self ._keys )._native_frame .isna ().any ().any ()
66+ and self .compliant .simple_select (* self ._keys ).native .isna ().any ().any ()
5867 ):
5968 msg = "Grouping by null values is not supported in pandas < 1.1.0"
6069 raise NotImplementedError (msg )
@@ -74,19 +83,21 @@ def __init__(
7483 )
7584
7685 def agg (self : Self , * exprs : PandasLikeExpr ) -> PandasLikeDataFrame : # noqa: PLR0915
77- implementation = self ._df ._implementation
78- backend_version = self ._df ._backend_version
86+ implementation = self .compliant ._implementation
87+ backend_version = self .compliant ._backend_version
7988 new_names : list [str ] = self ._keys .copy ()
8089
8190 all_aggs_are_simple = True
8291 for expr in exprs :
83- _ , aliases = evaluate_output_names_and_aliases (expr , self ._df , self ._keys )
92+ _ , aliases = evaluate_output_names_and_aliases (
93+ expr , self .compliant , self ._keys
94+ )
8495 new_names .extend (aliases )
8596
8697 if not (
8798 is_elementary_expression (expr )
8899 and re .sub (r"(\w+->)" , "" , expr ._function_name )
89- in AGGREGATIONS_TO_PANDAS_EQUIVALENT
100+ in self . _NARWHALS_TO_NATIVE_AGGREGATIONS
90101 ):
91102 all_aggs_are_simple = False
92103
@@ -111,11 +122,11 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR
111122 if all_aggs_are_simple :
112123 for expr in exprs :
113124 output_names , aliases = evaluate_output_names_and_aliases (
114- expr , self ._df , self ._keys
125+ expr , self .compliant , self ._keys
115126 )
116127 if expr ._depth == 0 :
117128 # e.g. agg(nw.len()) # noqa: ERA001
118- function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT .get (
129+ function_name = self . _NARWHALS_TO_NATIVE_AGGREGATIONS .get (
119130 expr ._function_name , expr ._function_name
120131 )
121132 simple_aggs_functions .add (function_name )
@@ -128,7 +139,7 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR
128139
129140 # e.g. agg(nw.mean('a')) # noqa: ERA001
130141 function_name = re .sub (r"(\w+->)" , "" , expr ._function_name )
131- function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT .get (
142+ function_name = self . _NARWHALS_TO_NATIVE_AGGREGATIONS .get (
132143 function_name , function_name
133144 )
134145
@@ -247,17 +258,17 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR
247258 )
248259 else :
249260 # No aggregation provided
250- result = self ._df .__native_namespace__ ().DataFrame (
261+ result = self .compliant .__native_namespace__ ().DataFrame (
251262 list (self ._grouped .groups .keys ()), columns = self ._keys
252263 )
253264 # Keep inplace=True to avoid making a redundant copy.
254265 # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
255266 result .reset_index (inplace = True ) # noqa: PD002
256- return self ._df ._from_native_frame (
267+ return self .compliant ._from_native_frame (
257268 select_columns_by_name (result , new_names , backend_version , implementation )
258269 )
259270
260- if self ._df . _native_frame .empty :
271+ if self .compliant . native .empty :
261272 # Don't even attempt this, it's way too inconsistent across pandas versions.
262273 msg = (
263274 "No results for group-by aggregation.\n \n "
@@ -285,9 +296,9 @@ def func(df: Any) -> Any:
285296 out_group = []
286297 out_names = []
287298 for expr in exprs :
288- results_keys = expr (self ._df ._from_native_frame (df ))
299+ results_keys = expr (self .compliant ._from_native_frame (df ))
289300 for result_keys in results_keys :
290- out_group .append (result_keys ._native_series .iloc [0 ])
301+ out_group .append (result_keys .native .iloc [0 ])
291302 out_names .append (result_keys .name )
292303 return native_series_from_iterable (
293304 out_group ,
@@ -305,7 +316,7 @@ def func(df: Any) -> Any:
305316 # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
306317 result_complex .reset_index (inplace = True ) # noqa: PD002
307318
308- return self ._df ._from_native_frame (
319+ return self .compliant ._from_native_frame (
309320 select_columns_by_name (
310321 result_complex , new_names , backend_version , implementation
311322 )
@@ -319,4 +330,4 @@ def __iter__(self: Self) -> Iterator[tuple[Any, PandasLikeDataFrame]]:
319330 category = FutureWarning ,
320331 )
321332 for key , group in self ._grouped :
322- yield (key , self ._df ._from_native_frame (group ))
333+ yield (key , self .compliant ._from_native_frame (group ))
0 commit comments