Skip to content

Commit 68d762a

Browse files
authored
feat: Allow mode(..., keep="any") (#3019)
1 parent 285815a commit 68d762a

File tree

17 files changed: 360 additions (+360) and 33 deletions (-33).

docs/api-reference/typing.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ Narwhals comes fully statically typed. In addition to `nw.DataFrame`, `nw.Expr`,
3030
- ConcatMethod
3131
- FillNullStrategy
3232
- JoinStrategy
33+
- ModeKeepStrategy
3334
- PivotAgg
3435
- RankMethod
3536
- RollingInterpolationMethod

narwhals/_arrow/series.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,7 @@
6666
FillNullStrategy,
6767
Into1DArray,
6868
IntoDType,
69+
ModeKeepStrategy,
6970
NonNestedLiteral,
7071
NumericLiteral,
7172
PythonLiteral,
@@ -856,16 +857,17 @@ def clip(
856857
def to_arrow(self) -> ArrayAny:
857858
return self.native.combine_chunks()
858859

859-
def mode(self) -> ArrowSeries:
860+
def mode(self, *, keep: ModeKeepStrategy) -> ArrowSeries:
860861
plx = self.__narwhals_namespace__()
861862
col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name])
862863
counts = self.value_counts(
863864
name=col_token, normalize=False, sort=False, parallel=False
864865
)
865-
return counts.filter(
866+
result = counts.filter(
866867
plx.col(col_token)
867868
== plx.col(col_token).max().broadcast(kind=ExprKind.AGGREGATION)
868869
).get_column(self.name)
870+
return result.head(1) if keep == "any" else result
869871

870872
def is_finite(self) -> Self:
871873
return self._with_native(pc.is_finite(self.native))

narwhals/_compliant/column.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
ClosedInterval,
2121
FillNullStrategy,
2222
IntoDType,
23+
ModeKeepStrategy,
2324
NonNestedLiteral,
2425
NumericLiteral,
2526
RankMethod,
@@ -174,7 +175,7 @@ def is_nan(self) -> Self: ...
174175
def is_null(self) -> Self: ...
175176
def is_unique(self) -> Self: ...
176177
def log(self, base: float) -> Self: ...
177-
def mode(self) -> Self: ...
178+
def mode(self, *, keep: ModeKeepStrategy) -> Self: ...
178179
def rank(self, method: RankMethod, *, descending: bool) -> Self: ...
179180
def replace_strict(
180181
self,

narwhals/_compliant/expr.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
ClosedInterval,
4646
FillNullStrategy,
4747
IntoDType,
48+
ModeKeepStrategy,
4849
NonNestedLiteral,
4950
NumericLiteral,
5051
RankMethod,
@@ -702,8 +703,8 @@ def len(self) -> Self:
702703
def gather_every(self, n: int, offset: int) -> Self:
703704
return self._reuse_series("gather_every", n=n, offset=offset)
704705

705-
def mode(self) -> Self:
706-
return self._reuse_series("mode")
706+
def mode(self, *, keep: ModeKeepStrategy) -> Self:
707+
return self._reuse_series("mode", scalar_kwargs={"keep": keep})
707708

708709
def is_finite(self) -> Self:
709710
return self._reuse_series("is_finite")

narwhals/_compliant/typing.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
from narwhals.typing import (
2525
FillNullStrategy,
2626
IntoLazyFrame,
27+
ModeKeepStrategy,
2728
NativeDataFrame,
2829
NativeFrame,
2930
NativeSeries,
@@ -43,6 +44,7 @@ class ScalarKwargs(TypedDict, total=False):
4344
half_life: float | None
4445
ignore_nulls: bool
4546
interpolation: RollingInterpolationMethod
47+
keep: ModeKeepStrategy
4648
limit: int | None
4749
method: RankMethod
4850
min_samples: int
@@ -180,6 +182,7 @@ class ScalarKwargs(TypedDict, total=False):
180182
"median",
181183
"max",
182184
"min",
185+
"mode",
183186
"std",
184187
"var",
185188
"len",

narwhals/_dask/expr.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434
from narwhals.typing import (
3535
FillNullStrategy,
3636
IntoDType,
37+
ModeKeepStrategy,
3738
NonNestedLiteral,
3839
NumericLiteral,
3940
RollingInterpolationMethod,
@@ -648,6 +649,14 @@ def sqrt(self) -> Self:
648649

649650
return self._with_callable(da.sqrt, "sqrt")
650651

652+
def mode(self, *, keep: ModeKeepStrategy) -> Self:
653+
def func(expr: dx.Series) -> dx.Series:
654+
_name = expr.name
655+
result = expr.to_frame().mode()[_name]
656+
return result.head(1) if keep == "any" else result
657+
658+
return self._with_callable(func, "mode", scalar_kwargs={"keep": keep})
659+
651660
@property
652661
def str(self) -> DaskExprStringNamespace:
653662
return DaskExprStringNamespace(self)
@@ -663,7 +672,6 @@ def dt(self) -> DaskExprDateTimeNamespace:
663672
gather_every: not_implemented = not_implemented()
664673
head: not_implemented = not_implemented()
665674
map_batches: not_implemented = not_implemented()
666-
mode: not_implemented = not_implemented()
667675
sample: not_implemented = not_implemented()
668676
rank: not_implemented = not_implemented()
669677
replace_strict: not_implemented = not_implemented()

narwhals/_pandas_like/group_by.py

Lines changed: 45 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
"mean",
4040
"median",
4141
"min",
42+
"mode",
4243
"nunique",
4344
"prod",
4445
"quantile",
@@ -115,6 +116,35 @@ def _getitem_aggs(
115116
result = ns._concat_horizontal(
116117
[ns.from_native(result_single).alias(name).native for name in names]
117118
)
119+
elif self.is_mode():
120+
compliant = group_by.compliant
121+
if (keep := self.kwargs.get("keep")) != "any": # pragma: no cover
122+
msg = (
123+
f"`Expr.mode(keep='{keep}')` is not implemented in group by context for "
124+
f"backend {compliant._implementation}\n\n"
125+
"Hint: Use `nw.col(...).mode(keep='any')` instead."
126+
)
127+
raise NotImplementedError(msg)
128+
129+
cols = list(names)
130+
native = compliant.native
131+
keys, kwargs = group_by._keys, group_by._kwargs
132+
133+
# Implementation based on the following suggestion:
134+
# https://github.com/pandas-dev/pandas/issues/19254#issuecomment-778661578
135+
ns = compliant.__narwhals_namespace__()
136+
result = ns._concat_horizontal(
137+
[
138+
native.groupby([*keys, col], **kwargs)
139+
.size()
140+
.sort_values(ascending=False)
141+
.reset_index(col)
142+
.groupby(keys, **kwargs)[col]
143+
.head(1)
144+
.sort_index()
145+
for col in cols
146+
]
147+
)
118148
else:
119149
select = names[0] if len(names) == 1 else list(names)
120150
result = self.native_agg()(group_by._grouped[select])
@@ -127,6 +157,9 @@ def _getitem_aggs(
127157
def is_len(self) -> bool:
128158
return self.leaf_name == "len"
129159

160+
def is_mode(self) -> bool:
161+
return self.leaf_name == "mode"
162+
130163
def is_top_level_function(self) -> bool:
131164
# e.g. `nw.len()`.
132165
return self.expr._depth == 0
@@ -158,6 +191,7 @@ class PandasLikeGroupBy(
158191
"median": "median",
159192
"max": "max",
160193
"min": "min",
194+
"mode": "mode",
161195
"std": "std",
162196
"var": "var",
163197
"len": "size",
@@ -176,6 +210,9 @@ class PandasLikeGroupBy(
176210
_output_key_names: list[str]
177211
"""Stores the **original** version of group keys."""
178212

213+
_kwargs: Mapping[str, bool]
214+
"""Stores keyword arguments for `DataFrame.groupby` other than `by`."""
215+
179216
@property
180217
def exclude(self) -> tuple[str, ...]:
181218
"""Group keys to ignore when expanding multi-output aggregations."""
@@ -200,13 +237,14 @@ def __init__(
200237
native = self.compliant.native
201238
if set(native.index.names).intersection(self.compliant.columns):
202239
native = native.reset_index(drop=True)
203-
self._grouped: NativeGroupBy = native.groupby(
204-
self._keys.copy(),
205-
sort=False,
206-
as_index=True,
207-
dropna=drop_null_keys,
208-
observed=True,
209-
)
240+
241+
self._kwargs = {
242+
"sort": False,
243+
"as_index": True,
244+
"dropna": drop_null_keys,
245+
"observed": True,
246+
}
247+
self._grouped: NativeGroupBy = native.groupby(self._keys.copy(), **self._kwargs)
210248

211249
def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame:
212250
all_aggs_are_simple = True

narwhals/_pandas_like/series.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@
4747
FillNullStrategy,
4848
Into1DArray,
4949
IntoDType,
50+
ModeKeepStrategy,
5051
NonNestedLiteral,
5152
NumericLiteral,
5253
RankMethod,
@@ -849,10 +850,10 @@ def to_arrow(self) -> pa.Array[Any]:
849850

850851
return pa.Array.from_pandas(self.native)
851852

852-
def mode(self) -> Self:
853+
def mode(self, *, keep: ModeKeepStrategy) -> Self:
853854
result = self.native.mode()
854855
result.name = self.name
855-
return self._with_native(result)
856+
return self._with_native(result.head(1) if keep == "any" else result)
856857

857858
def cum_count(self, *, reverse: bool) -> Self:
858859
not_na_series = ~self.native.isna()

narwhals/_polars/expr.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
from narwhals._polars.dataframe import Method, PolarsDataFrame
2828
from narwhals._polars.namespace import PolarsNamespace
2929
from narwhals._utils import Version, _LimitedContext
30-
from narwhals.typing import IntoDType, NumericLiteral
30+
from narwhals.typing import IntoDType, ModeKeepStrategy, NumericLiteral
3131

3232

3333
class PolarsExpr:
@@ -280,6 +280,10 @@ def is_close(
280280
)
281281
return self._with_native(result)
282282

283+
def mode(self, *, keep: ModeKeepStrategy) -> Self:
284+
result = self.native.mode()
285+
return self._with_native(result.first() if keep == "any" else result)
286+
283287
@property
284288
def dt(self) -> PolarsExprDateTimeNamespace:
285289
return PolarsExprDateTimeNamespace(self)
@@ -366,7 +370,6 @@ def _eval_names_indices(indices: Sequence[int], /) -> EvalNames[PolarsDataFrame]
366370
mean: Method[Self]
367371
median: Method[Self]
368372
min: Method[Self]
369-
mode: Method[Self]
370373
n_unique: Method[Self]
371374
null_count: Method[Self]
372375
quantile: Method[Self]

narwhals/_polars/series.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
from narwhals.typing import (
4242
Into1DArray,
4343
IntoDType,
44+
ModeKeepStrategy,
4445
MultiIndexSelector,
4546
NonNestedLiteral,
4647
NumericLiteral,
@@ -518,6 +519,10 @@ def is_close(
518519
)
519520
return self._with_native(result)
520521

522+
def mode(self, *, keep: ModeKeepStrategy) -> Self:
523+
result = self.native.mode()
524+
return self._with_native(result.head(1) if keep == "any" else result)
525+
521526
def hist_from_bins(
522527
self, bins: list[float], *, include_breakpoint: bool
523528
) -> PolarsDataFrame:
@@ -702,7 +707,6 @@ def struct(self) -> PolarsSeriesStructNamespace:
702707
max: Method[Any]
703708
mean: Method[float]
704709
min: Method[Any]
705-
mode: Method[Self]
706710
n_unique: Method[int]
707711
null_count: Method[int]
708712
quantile: Method[float]

0 commit comments

Comments
 (0)