|
1 | 1 | from __future__ import annotations |
| 2 | + |
| 3 | +from itertools import chain |
| 4 | +from typing import TYPE_CHECKING, Any, Protocol |
| 5 | + |
| 6 | +from narwhals._plan._expansion import prepare_projection |
| 7 | +from narwhals._plan._parse import parse_into_seq_of_expr_ir |
| 8 | +from narwhals._plan.common import replace, temp |
| 9 | +from narwhals._plan.compliant.typing import ( |
| 10 | + DataFrameT, |
| 11 | + EagerDataFrameT, |
| 12 | + FrameT_co, |
| 13 | + ResolverT_co, |
| 14 | +) |
| 15 | +from narwhals.exceptions import ComputeError |
| 16 | + |
| 17 | +if TYPE_CHECKING: |
| 18 | + from collections.abc import Iterator |
| 19 | + |
| 20 | + from typing_extensions import Self |
| 21 | + |
| 22 | + from narwhals._plan.expressions import ExprIR, NamedIR |
| 23 | + from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema |
| 24 | + from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq |
| 25 | + |
| 26 | + |
| 27 | +class CompliantGroupBy(Protocol[FrameT_co]): |
| 28 | + @property |
| 29 | + def compliant(self) -> FrameT_co: ... |
| 30 | + def agg(self, irs: Seq[NamedIR]) -> FrameT_co: ... |
| 31 | + |
| 32 | + |
| 33 | +class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): |
| 34 | + _keys: Seq[NamedIR] |
| 35 | + _key_names: Seq[str] |
| 36 | + |
| 37 | + @classmethod |
| 38 | + def from_resolver( |
| 39 | + cls, df: DataFrameT, resolver: GroupByResolver, / |
| 40 | + ) -> DataFrameGroupBy[DataFrameT]: ... |
| 41 | + @classmethod |
| 42 | + def by_names( |
| 43 | + cls, df: DataFrameT, names: Seq[str], / |
| 44 | + ) -> DataFrameGroupBy[DataFrameT]: ... |
| 45 | + def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ... |
| 46 | + @property |
| 47 | + def keys(self) -> Seq[NamedIR]: |
| 48 | + return self._keys |
| 49 | + |
| 50 | + @property |
| 51 | + def key_names(self) -> Seq[str]: |
| 52 | + if names := self._key_names: |
| 53 | + return names |
| 54 | + msg = "at least one key is required in a group_by operation" |
| 55 | + raise ComputeError(msg) |
| 56 | + |
| 57 | + |
| 58 | +class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDataFrameT]): |
| 59 | + _df: EagerDataFrameT |
| 60 | + _key_names: Seq[str] |
| 61 | + _key_names_original: Seq[str] |
| 62 | + _column_names_original: Seq[str] |
| 63 | + |
| 64 | + @classmethod |
| 65 | + def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: |
| 66 | + obj = cls.__new__(cls) |
| 67 | + obj._df = df |
| 68 | + obj._keys = () |
| 69 | + obj._key_names = names |
| 70 | + obj._key_names_original = () |
| 71 | + obj._column_names_original = tuple(df.columns) |
| 72 | + return obj |
| 73 | + |
| 74 | + @classmethod |
| 75 | + def from_resolver( |
| 76 | + cls, df: EagerDataFrameT, resolver: GroupByResolver, / |
| 77 | + ) -> EagerDataFrameGroupBy[EagerDataFrameT]: |
| 78 | + key_names = resolver.key_names |
| 79 | + if not resolver.requires_projection(): |
| 80 | + df = df.drop_nulls(key_names) if resolver._drop_null_keys else df |
| 81 | + return cls.by_names(df, key_names) |
| 82 | + obj = cls.__new__(cls) |
| 83 | + unique_names = temp.column_names(chain(key_names, df.columns)) |
| 84 | + safe_keys = tuple( |
| 85 | + replace(key, name=name) for key, name in zip(resolver.keys, unique_names) |
| 86 | + ) |
| 87 | + obj._df = df.with_columns(resolver._schema_in.with_columns_irs(safe_keys)) |
| 88 | + obj._keys = safe_keys |
| 89 | + obj._key_names = tuple(e.name for e in safe_keys) |
| 90 | + obj._key_names_original = key_names |
| 91 | + obj._column_names_original = resolver._schema_in.names |
| 92 | + return obj |
| 93 | + |
| 94 | + |
| 95 | +class Grouper(Protocol[ResolverT_co]): |
| 96 | + """`GroupBy` helper for collecting and forwarding `Expr`s for projection. |
| 97 | +
|
| 98 | + - Uses `Expr` everywhere (no need to duplicate layers) |
| 99 | + - Resolver only needs schema (neither needs a frame, but can use one to get `schema`) |
| 100 | + """ |
| 101 | + |
| 102 | + _keys: Seq[ExprIR] |
| 103 | + _aggs: Seq[ExprIR] |
| 104 | + _drop_null_keys: bool |
| 105 | + |
| 106 | + @classmethod |
| 107 | + def by(cls, *by: OneOrIterable[IntoExpr]) -> Self: |
| 108 | + obj = cls.__new__(cls) |
| 109 | + obj._keys = parse_into_seq_of_expr_ir(*by) |
| 110 | + return obj |
| 111 | + |
| 112 | + def agg(self, *aggs: OneOrIterable[IntoExpr]) -> Self: |
| 113 | + self._aggs = parse_into_seq_of_expr_ir(*aggs) |
| 114 | + return self |
| 115 | + |
| 116 | + @property |
| 117 | + def _resolver(self) -> type[ResolverT_co]: ... |
| 118 | + |
| 119 | + def resolve(self, context: IntoFrozenSchema, /) -> ResolverT_co: |
| 120 | + """Project keys and aggs in `context`, expanding all `Expr` -> `NamedIR`.""" |
| 121 | + return self._resolver.from_grouper(self, context) |
| 122 | + |
| 123 | + |
| 124 | +class GroupByResolver: |
| 125 | + """Narwhals-level `GroupBy` resolver.""" |
| 126 | + |
| 127 | + _schema_in: FrozenSchema |
| 128 | + _keys: Seq[NamedIR] |
| 129 | + _aggs: Seq[NamedIR] |
| 130 | + _key_names: Seq[str] |
| 131 | + _schema: FrozenSchema |
| 132 | + _drop_null_keys: bool |
| 133 | + |
| 134 | + @property |
| 135 | + def keys(self) -> Seq[NamedIR]: |
| 136 | + return self._keys |
| 137 | + |
| 138 | + @property |
| 139 | + def aggs(self) -> Seq[NamedIR]: |
| 140 | + return self._aggs |
| 141 | + |
| 142 | + @property |
| 143 | + def key_names(self) -> Seq[str]: |
| 144 | + if names := self._key_names: |
| 145 | + return names |
| 146 | + if keys := self.keys: |
| 147 | + return tuple(e.name for e in keys) |
| 148 | + msg = "at least one key is required in a group_by operation" |
| 149 | + raise ComputeError(msg) |
| 150 | + |
| 151 | + @property |
| 152 | + def schema(self) -> FrozenSchema: |
| 153 | + return self._schema |
| 154 | + |
| 155 | + def evaluate(self, frame: DataFrameT) -> DataFrameT: |
| 156 | + """Perform the `group_by` on `frame`.""" |
| 157 | + return frame.group_by_resolver(self).agg(self.aggs) |
| 158 | + |
| 159 | + @classmethod |
| 160 | + def from_grouper(cls, grouper: Grouper[Self], context: IntoFrozenSchema, /) -> Self: |
| 161 | + """Loosely based on [`resolve_group_by`]. |
| 162 | +
|
| 163 | + [`resolve_group_by`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1125-L1227 |
| 164 | + """ |
| 165 | + obj = cls.__new__(cls) |
| 166 | + keys, schema_in = prepare_projection(grouper._keys, schema=context) |
| 167 | + obj._keys, obj._schema_in = keys, schema_in |
| 168 | + obj._key_names = tuple(e.name for e in keys) |
| 169 | + obj._aggs, _ = prepare_projection(grouper._aggs, obj.key_names, schema=schema_in) |
| 170 | + obj._schema = schema_in.select(keys).merge(schema_in.select(obj._aggs)) |
| 171 | + obj._drop_null_keys = grouper._drop_null_keys |
| 172 | + return obj |
| 173 | + |
| 174 | + def requires_projection(self, *, allow_aliasing: bool = False) -> bool: |
| 175 | + """Return True is group keys contain anything that is not a column selection. |
| 176 | +
|
| 177 | + Notes: |
| 178 | + If False is returned, we can just use the resolved key names as a fast-path to group. |
| 179 | +
|
| 180 | + Arguments: |
| 181 | + allow_aliasing: If False (default), any aliasing is not considered to be column selection. |
| 182 | + """ |
| 183 | + if not all(key.is_column(allow_aliasing=allow_aliasing) for key in self.keys): |
| 184 | + if self._drop_null_keys: |
| 185 | + msg = "drop_null_keys cannot be True when keys contains Expr or Series" |
| 186 | + raise NotImplementedError(msg) |
| 187 | + return True |
| 188 | + return False |
| 189 | + |
| 190 | + |
| 191 | +class Resolved(GroupByResolver): |
| 192 | + """Compliant-level `GroupBy` resolver.""" |
| 193 | + |
| 194 | + _drop_null_keys: bool = False |
| 195 | + |
| 196 | + |
| 197 | +class Grouped(Grouper[Resolved]): |
| 198 | + """Compliant-level `GroupBy` helper.""" |
| 199 | + |
| 200 | + _keys: Seq[ExprIR] |
| 201 | + _aggs: Seq[ExprIR] |
| 202 | + _drop_null_keys: bool = False |
| 203 | + |
| 204 | + @property |
| 205 | + def _resolver(self) -> type[Resolved]: |
| 206 | + return Resolved |
0 commit comments