Skip to content

Commit 9637662

Browse files
committed
refactor: Implement PandasLikeGroupBy?
- Having a hard time working out what is going on here; all I've changed is what the references are named.
1 parent 41ea47a commit 9637662

File tree

6 files changed

+59
-43
lines changed

6 files changed

+59
-43
lines changed

narwhals/_arrow/group_by.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
self,
4646
compliant_frame: ArrowDataFrame,
4747
keys: Sequence[str],
48+
/,
4849
*,
4950
drop_null_keys: bool,
5051
) -> None:

narwhals/_compliant/group_by.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(
3636
self,
3737
compliant_frame: CompliantFrameT_co,
3838
keys: Sequence[str],
39+
/,
3940
*,
4041
drop_null_keys: bool,
4142
) -> None: ...

narwhals/_dask/expr.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,11 @@ def over(
550550
order_by: Sequence[str] | None,
551551
) -> Self:
552552
# pandas is a required dependency of dask so it's safe to import this
553-
from narwhals._pandas_like.group_by import AGGREGATIONS_TO_PANDAS_EQUIVALENT
553+
from narwhals._pandas_like.group_by import PandasLikeGroupBy
554+
555+
AGGREGATIONS_TO_PANDAS_EQUIVALENT = ( # noqa: N806
556+
PandasLikeGroupBy._NARWHALS_TO_NATIVE_AGGREGATIONS
557+
)
554558

555559
if not partition_by:
556560
assert order_by is not None # help type checkers # noqa: S101

narwhals/_pandas_like/dataframe.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -580,11 +580,7 @@ def collect(
580580
def group_by(self: Self, *keys: str, drop_null_keys: bool) -> PandasLikeGroupBy:
581581
from narwhals._pandas_like.group_by import PandasLikeGroupBy
582582

583-
return PandasLikeGroupBy(
584-
self,
585-
list(keys),
586-
drop_null_keys=drop_null_keys,
587-
)
583+
return PandasLikeGroupBy(self, keys, drop_null_keys=drop_null_keys)
588584

589585
def join(
590586
self: Self,

narwhals/_pandas_like/expr.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from narwhals._expression_parsing import ExprKind
1212
from narwhals._expression_parsing import evaluate_output_names_and_aliases
1313
from narwhals._expression_parsing import is_elementary_expression
14-
from narwhals._pandas_like.group_by import AGGREGATIONS_TO_PANDAS_EQUIVALENT
14+
from narwhals._pandas_like.group_by import PandasLikeGroupBy
1515
from narwhals._pandas_like.series import PandasLikeSeries
1616
from narwhals.exceptions import ColumnNotFoundError
1717
from narwhals.utils import generate_temporary_column_name
@@ -223,6 +223,9 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
223223
)
224224
raise NotImplementedError(msg)
225225
else:
226+
AGGREGATIONS_TO_PANDAS_EQUIVALENT = ( # noqa: N806
227+
PandasLikeGroupBy._NARWHALS_TO_NATIVE_AGGREGATIONS
228+
)
226229
function_name: str = re.sub(r"(\w+->)", "", self._function_name)
227230
pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get(
228231
function_name,

narwhals/_pandas_like/group_by.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@
55
import warnings
66
from typing import TYPE_CHECKING
77
from typing import Any
8+
from typing import ClassVar
89
from typing import Iterator
10+
from typing import Mapping
11+
from typing import Sequence
912

13+
from narwhals._compliant import EagerGroupBy
1014
from narwhals._expression_parsing import evaluate_output_names_and_aliases
1115
from narwhals._expression_parsing import is_elementary_expression
1216
from narwhals._pandas_like.utils import horizontal_concat
@@ -22,39 +26,44 @@
2226
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
2327
from narwhals._pandas_like.expr import PandasLikeExpr
2428

25-
AGGREGATIONS_TO_PANDAS_EQUIVALENT = {
26-
"sum": "sum",
27-
"mean": "mean",
28-
"median": "median",
29-
"max": "max",
30-
"min": "min",
31-
"std": "std",
32-
"var": "var",
33-
"len": "size",
34-
"n_unique": "nunique",
35-
"count": "count",
36-
}
3729

30+
class PandasLikeGroupBy(EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr"]):
31+
_NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] = {
32+
"sum": "sum",
33+
"mean": "mean",
34+
"median": "median",
35+
"max": "max",
36+
"min": "min",
37+
"std": "std",
38+
"var": "var",
39+
"len": "size",
40+
"n_unique": "nunique",
41+
"count": "count",
42+
}
3843

39-
class PandasLikeGroupBy:
4044
def __init__(
41-
self: Self, df: PandasLikeDataFrame, keys: list[str], *, drop_null_keys: bool
45+
self: Self,
46+
df: PandasLikeDataFrame,
47+
keys: Sequence[str],
48+
/,
49+
*,
50+
drop_null_keys: bool,
4251
) -> None:
43-
self._df = df
44-
self._keys = keys
52+
self._compliant_frame = df
53+
self._keys: list[str] = list(keys)
4554
# Drop index to avoid potential collisions:
4655
# https://github.com/narwhals-dev/narwhals/issues/1907.
47-
if set(df._native_frame.index.names).intersection(df.columns):
48-
native_frame = df._native_frame.reset_index(drop=True)
56+
if set(df.native.index.names).intersection(df.columns):
57+
native_frame = df.native.reset_index(drop=True)
4958
else:
50-
native_frame = df._native_frame
59+
native_frame = df.native
5160
if (
52-
self._df._implementation is Implementation.PANDAS
53-
and self._df._backend_version < (1, 1)
61+
self.compliant._implementation is Implementation.PANDAS
62+
and self.compliant._backend_version < (1, 1)
5463
): # pragma: no cover
5564
if (
5665
not drop_null_keys
57-
and self._df.simple_select(*self._keys)._native_frame.isna().any().any()
66+
and self.compliant.simple_select(*self._keys).native.isna().any().any()
5867
):
5968
msg = "Grouping by null values is not supported in pandas < 1.1.0"
6069
raise NotImplementedError(msg)
@@ -74,19 +83,21 @@ def __init__(
7483
)
7584

7685
def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR0915
77-
implementation = self._df._implementation
78-
backend_version = self._df._backend_version
86+
implementation = self.compliant._implementation
87+
backend_version = self.compliant._backend_version
7988
new_names: list[str] = self._keys.copy()
8089

8190
all_aggs_are_simple = True
8291
for expr in exprs:
83-
_, aliases = evaluate_output_names_and_aliases(expr, self._df, self._keys)
92+
_, aliases = evaluate_output_names_and_aliases(
93+
expr, self.compliant, self._keys
94+
)
8495
new_names.extend(aliases)
8596

8697
if not (
8798
is_elementary_expression(expr)
8899
and re.sub(r"(\w+->)", "", expr._function_name)
89-
in AGGREGATIONS_TO_PANDAS_EQUIVALENT
100+
in self._NARWHALS_TO_NATIVE_AGGREGATIONS
90101
):
91102
all_aggs_are_simple = False
92103

@@ -111,11 +122,11 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR
111122
if all_aggs_are_simple:
112123
for expr in exprs:
113124
output_names, aliases = evaluate_output_names_and_aliases(
114-
expr, self._df, self._keys
125+
expr, self.compliant, self._keys
115126
)
116127
if expr._depth == 0:
117128
# e.g. agg(nw.len()) # noqa: ERA001
118-
function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT.get(
129+
function_name = self._NARWHALS_TO_NATIVE_AGGREGATIONS.get(
119130
expr._function_name, expr._function_name
120131
)
121132
simple_aggs_functions.add(function_name)
@@ -128,7 +139,7 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR
128139

129140
# e.g. agg(nw.mean('a')) # noqa: ERA001
130141
function_name = re.sub(r"(\w+->)", "", expr._function_name)
131-
function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT.get(
142+
function_name = self._NARWHALS_TO_NATIVE_AGGREGATIONS.get(
132143
function_name, function_name
133144
)
134145

@@ -247,17 +258,17 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR
247258
)
248259
else:
249260
# No aggregation provided
250-
result = self._df.__native_namespace__().DataFrame(
261+
result = self.compliant.__native_namespace__().DataFrame(
251262
list(self._grouped.groups.keys()), columns=self._keys
252263
)
253264
# Keep inplace=True to avoid making a redundant copy.
254265
# This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
255266
result.reset_index(inplace=True) # noqa: PD002
256-
return self._df._from_native_frame(
267+
return self.compliant._from_native_frame(
257268
select_columns_by_name(result, new_names, backend_version, implementation)
258269
)
259270

260-
if self._df._native_frame.empty:
271+
if self.compliant.native.empty:
261272
# Don't even attempt this, it's way too inconsistent across pandas versions.
262273
msg = (
263274
"No results for group-by aggregation.\n\n"
@@ -285,9 +296,9 @@ def func(df: Any) -> Any:
285296
out_group = []
286297
out_names = []
287298
for expr in exprs:
288-
results_keys = expr(self._df._from_native_frame(df))
299+
results_keys = expr(self.compliant._from_native_frame(df))
289300
for result_keys in results_keys:
290-
out_group.append(result_keys._native_series.iloc[0])
301+
out_group.append(result_keys.native.iloc[0])
291302
out_names.append(result_keys.name)
292303
return native_series_from_iterable(
293304
out_group,
@@ -305,7 +316,7 @@ def func(df: Any) -> Any:
305316
# This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
306317
result_complex.reset_index(inplace=True) # noqa: PD002
307318

308-
return self._df._from_native_frame(
319+
return self.compliant._from_native_frame(
309320
select_columns_by_name(
310321
result_complex, new_names, backend_version, implementation
311322
)
@@ -319,4 +330,4 @@ def __iter__(self: Self) -> Iterator[tuple[Any, PandasLikeDataFrame]]:
319330
category=FutureWarning,
320331
)
321332
for key, group in self._grouped:
322-
yield (key, self._df._from_native_frame(group))
333+
yield (key, self.compliant._from_native_frame(group))

0 commit comments

Comments
 (0)