Skip to content

Commit 3fc7495

Browse files
committed
refactor: Move *GroupBy + friends
1 parent 2791431 commit 3fc7495

File tree

5 files changed

+217
-196
lines changed

5 files changed

+217
-196
lines changed

narwhals/_plan/arrow/group_by.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from narwhals._plan._guards import is_agg_expr, is_function_expr
1010
from narwhals._plan.arrow import acero, functions as fn, options
1111
from narwhals._plan.common import dispatch_method_name, temp
12+
from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy
1213
from narwhals._plan.expressions import aggregation as agg
13-
from narwhals._plan.protocols import EagerDataFrameGroupBy
1414
from narwhals._utils import Implementation
1515
from narwhals.exceptions import InvalidOperationError
1616

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,206 @@
11
from __future__ import annotations
2+
3+
from itertools import chain
4+
from typing import TYPE_CHECKING, Any, Protocol
5+
6+
from narwhals._plan._expansion import prepare_projection
7+
from narwhals._plan._parse import parse_into_seq_of_expr_ir
8+
from narwhals._plan.common import replace, temp
9+
from narwhals._plan.compliant.typing import (
10+
DataFrameT,
11+
EagerDataFrameT,
12+
FrameT_co,
13+
ResolverT_co,
14+
)
15+
from narwhals.exceptions import ComputeError
16+
17+
if TYPE_CHECKING:
18+
from collections.abc import Iterator
19+
20+
from typing_extensions import Self
21+
22+
from narwhals._plan.expressions import ExprIR, NamedIR
23+
from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema
24+
from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq
25+
26+
27+
class CompliantGroupBy(Protocol[FrameT_co]):
28+
@property
29+
def compliant(self) -> FrameT_co: ...
30+
def agg(self, irs: Seq[NamedIR]) -> FrameT_co: ...
31+
32+
33+
class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]):
34+
_keys: Seq[NamedIR]
35+
_key_names: Seq[str]
36+
37+
@classmethod
38+
def from_resolver(
39+
cls, df: DataFrameT, resolver: GroupByResolver, /
40+
) -> DataFrameGroupBy[DataFrameT]: ...
41+
@classmethod
42+
def by_names(
43+
cls, df: DataFrameT, names: Seq[str], /
44+
) -> DataFrameGroupBy[DataFrameT]: ...
45+
def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ...
46+
@property
47+
def keys(self) -> Seq[NamedIR]:
48+
return self._keys
49+
50+
@property
51+
def key_names(self) -> Seq[str]:
52+
if names := self._key_names:
53+
return names
54+
msg = "at least one key is required in a group_by operation"
55+
raise ComputeError(msg)
56+
57+
58+
class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDataFrameT]):
59+
_df: EagerDataFrameT
60+
_key_names: Seq[str]
61+
_key_names_original: Seq[str]
62+
_column_names_original: Seq[str]
63+
64+
@classmethod
65+
def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self:
66+
obj = cls.__new__(cls)
67+
obj._df = df
68+
obj._keys = ()
69+
obj._key_names = names
70+
obj._key_names_original = ()
71+
obj._column_names_original = tuple(df.columns)
72+
return obj
73+
74+
@classmethod
75+
def from_resolver(
76+
cls, df: EagerDataFrameT, resolver: GroupByResolver, /
77+
) -> EagerDataFrameGroupBy[EagerDataFrameT]:
78+
key_names = resolver.key_names
79+
if not resolver.requires_projection():
80+
df = df.drop_nulls(key_names) if resolver._drop_null_keys else df
81+
return cls.by_names(df, key_names)
82+
obj = cls.__new__(cls)
83+
unique_names = temp.column_names(chain(key_names, df.columns))
84+
safe_keys = tuple(
85+
replace(key, name=name) for key, name in zip(resolver.keys, unique_names)
86+
)
87+
obj._df = df.with_columns(resolver._schema_in.with_columns_irs(safe_keys))
88+
obj._keys = safe_keys
89+
obj._key_names = tuple(e.name for e in safe_keys)
90+
obj._key_names_original = key_names
91+
obj._column_names_original = resolver._schema_in.names
92+
return obj
93+
94+
95+
class Grouper(Protocol[ResolverT_co]):
96+
"""`GroupBy` helper for collecting and forwarding `Expr`s for projection.
97+
98+
- Uses `Expr` everywhere (no need to duplicate layers)
99+
- Resolver only needs schema (neither needs a frame, but can use one to get `schema`)
100+
"""
101+
102+
_keys: Seq[ExprIR]
103+
_aggs: Seq[ExprIR]
104+
_drop_null_keys: bool
105+
106+
@classmethod
107+
def by(cls, *by: OneOrIterable[IntoExpr]) -> Self:
108+
obj = cls.__new__(cls)
109+
obj._keys = parse_into_seq_of_expr_ir(*by)
110+
return obj
111+
112+
def agg(self, *aggs: OneOrIterable[IntoExpr]) -> Self:
113+
self._aggs = parse_into_seq_of_expr_ir(*aggs)
114+
return self
115+
116+
@property
117+
def _resolver(self) -> type[ResolverT_co]: ...
118+
119+
def resolve(self, context: IntoFrozenSchema, /) -> ResolverT_co:
120+
"""Project keys and aggs in `context`, expanding all `Expr` -> `NamedIR`."""
121+
return self._resolver.from_grouper(self, context)
122+
123+
124+
class GroupByResolver:
125+
"""Narwhals-level `GroupBy` resolver."""
126+
127+
_schema_in: FrozenSchema
128+
_keys: Seq[NamedIR]
129+
_aggs: Seq[NamedIR]
130+
_key_names: Seq[str]
131+
_schema: FrozenSchema
132+
_drop_null_keys: bool
133+
134+
@property
135+
def keys(self) -> Seq[NamedIR]:
136+
return self._keys
137+
138+
@property
139+
def aggs(self) -> Seq[NamedIR]:
140+
return self._aggs
141+
142+
@property
143+
def key_names(self) -> Seq[str]:
144+
if names := self._key_names:
145+
return names
146+
if keys := self.keys:
147+
return tuple(e.name for e in keys)
148+
msg = "at least one key is required in a group_by operation"
149+
raise ComputeError(msg)
150+
151+
@property
152+
def schema(self) -> FrozenSchema:
153+
return self._schema
154+
155+
def evaluate(self, frame: DataFrameT) -> DataFrameT:
156+
"""Perform the `group_by` on `frame`."""
157+
return frame.group_by_resolver(self).agg(self.aggs)
158+
159+
@classmethod
160+
def from_grouper(cls, grouper: Grouper[Self], context: IntoFrozenSchema, /) -> Self:
161+
"""Loosely based on [`resolve_group_by`].
162+
163+
[`resolve_group_by`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1125-L1227
164+
"""
165+
obj = cls.__new__(cls)
166+
keys, schema_in = prepare_projection(grouper._keys, schema=context)
167+
obj._keys, obj._schema_in = keys, schema_in
168+
obj._key_names = tuple(e.name for e in keys)
169+
obj._aggs, _ = prepare_projection(grouper._aggs, obj.key_names, schema=schema_in)
170+
obj._schema = schema_in.select(keys).merge(schema_in.select(obj._aggs))
171+
obj._drop_null_keys = grouper._drop_null_keys
172+
return obj
173+
174+
def requires_projection(self, *, allow_aliasing: bool = False) -> bool:
175+
"""Return True is group keys contain anything that is not a column selection.
176+
177+
Notes:
178+
If False is returned, we can just use the resolved key names as a fast-path to group.
179+
180+
Arguments:
181+
allow_aliasing: If False (default), any aliasing is not considered to be column selection.
182+
"""
183+
if not all(key.is_column(allow_aliasing=allow_aliasing) for key in self.keys):
184+
if self._drop_null_keys:
185+
msg = "drop_null_keys cannot be True when keys contains Expr or Series"
186+
raise NotImplementedError(msg)
187+
return True
188+
return False
189+
190+
191+
class Resolved(GroupByResolver):
192+
"""Compliant-level `GroupBy` resolver."""
193+
194+
_drop_null_keys: bool = False
195+
196+
197+
class Grouped(Grouper[Resolved]):
198+
"""Compliant-level `GroupBy` helper."""
199+
200+
_keys: Seq[ExprIR]
201+
_aggs: Seq[ExprIR]
202+
_drop_null_keys: bool = False
203+
204+
@property
205+
def _resolver(self) -> type[Resolved]:
206+
return Resolved

narwhals/_plan/compliant/typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
if TYPE_CHECKING:
88
from typing_extensions import TypeAlias
99

10+
from narwhals._plan.compliant.group_by import GroupByResolver
1011
from narwhals._plan.compliant.namespace import CompliantNamespace
1112
from narwhals._plan.compliant.series import CompliantSeries
1213
from narwhals._plan.protocols import (
@@ -18,7 +19,6 @@
1819
EagerExpr,
1920
EagerScalar,
2021
ExprDispatch,
21-
GroupByResolver,
2222
LazyExpr,
2323
LazyScalar,
2424
)

narwhals/_plan/group_by.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import TYPE_CHECKING, Any, Generic
44

55
from narwhals._plan._parse import parse_into_seq_of_expr_ir
6-
from narwhals._plan.protocols import GroupByResolver as Resolved, Grouper
6+
from narwhals._plan.compliant.group_by import GroupByResolver as Resolved, Grouper
77
from narwhals._plan.typing import DataFrameT
88

99
if TYPE_CHECKING:

0 commit comments

Comments
 (0)