Skip to content

Commit ff792be

Browse files
feat: Add support for quantile and ewm_mean in over context (#2774)
--------- Co-authored-by: Marco Edward Gorelli <[email protected]>
1 parent 93485ce commit ff792be

File tree

6 files changed: 130 additions (+130) and 12 deletions (-12).

narwhals/_compliant/expr.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -740,8 +740,7 @@ def quantile(
740740
return self._reuse_series(
741741
"quantile",
742742
returns_scalar=True,
743-
quantile=quantile,
744-
interpolation=interpolation,
743+
scalar_kwargs={"quantile": quantile, "interpolation": interpolation},
745744
)
746745

747746
def head(self, n: int) -> Self:

narwhals/_compliant/typing.py

Lines changed: 26 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,19 +21,33 @@
2121
from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
2222
from narwhals._compliant.series import CompliantSeries, EagerSeries
2323
from narwhals._compliant.window import WindowInputs
24-
from narwhals.typing import FillNullStrategy, NativeFrame, NativeSeries, RankMethod
24+
from narwhals.typing import (
25+
FillNullStrategy,
26+
NativeFrame,
27+
NativeSeries,
28+
RankMethod,
29+
RollingInterpolationMethod,
30+
)
2531

2632
class ScalarKwargs(TypedDict, total=False):
2733
"""Non-expressifiable args which we may need to reuse in `agg` or `over`."""
2834

35+
adjust: bool
36+
alpha: float | None
2937
center: int
38+
com: float | None
3039
ddof: int
3140
descending: bool
41+
half_life: float | None
42+
ignore_nulls: bool
43+
interpolation: RollingInterpolationMethod
3244
limit: int | None
3345
method: RankMethod
3446
min_samples: int
3547
n: int
48+
quantile: float
3649
reverse: bool
50+
span: float | None
3751
strategy: FillNullStrategy | None
3852
window_size: int
3953

@@ -157,7 +171,17 @@ class ScalarKwargs(TypedDict, total=False):
157171
"""A function evaluated with `over(partition_by=..., order_by=...)`."""
158172

159173
NarwhalsAggregation: TypeAlias = Literal[
160-
"sum", "mean", "median", "max", "min", "std", "var", "len", "n_unique", "count"
174+
"sum",
175+
"mean",
176+
"median",
177+
"max",
178+
"min",
179+
"std",
180+
"var",
181+
"len",
182+
"n_unique",
183+
"count",
184+
"quantile",
161185
]
162186
"""`Expr` methods we aim to support in `DepthTrackingGroupBy`.
163187

narwhals/_dask/group_by.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,7 @@ class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregat
6464
"len": "size",
6565
"n_unique": n_unique,
6666
"count": "count",
67+
"quantile": "quantile",
6768
}
6869

6970
def __init__(

narwhals/_pandas_like/expr.py

Lines changed: 49 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,8 @@
3737
"rank": "rank",
3838
"diff": "diff",
3939
"fill_null": "fillna",
40+
"quantile": "quantile",
41+
"ewm_mean": "mean",
4042
}
4143

4244

@@ -74,6 +76,31 @@ def window_kwargs_to_pandas_equivalent(
7476
assert "strategy" in kwargs # noqa: S101
7577
assert "limit" in kwargs # noqa: S101
7678
pandas_kwargs = {"strategy": kwargs["strategy"], "limit": kwargs["limit"]}
79+
elif function_name == "quantile":
80+
assert "quantile" in kwargs # noqa: S101
81+
assert "interpolation" in kwargs # noqa: S101
82+
pandas_kwargs = {
83+
"q": kwargs["quantile"],
84+
"interpolation": kwargs["interpolation"],
85+
}
86+
elif function_name.startswith("ewm_"):
87+
assert "com" in kwargs # noqa: S101
88+
assert "span" in kwargs # noqa: S101
89+
assert "half_life" in kwargs # noqa: S101
90+
assert "alpha" in kwargs # noqa: S101
91+
assert "adjust" in kwargs # noqa: S101
92+
assert "min_samples" in kwargs # noqa: S101
93+
assert "ignore_nulls" in kwargs # noqa: S101
94+
95+
pandas_kwargs = {
96+
"com": kwargs["com"],
97+
"span": kwargs["span"],
98+
"halflife": kwargs["half_life"],
99+
"alpha": kwargs["alpha"],
100+
"adjust": kwargs["adjust"],
101+
"min_periods": kwargs["min_samples"],
102+
"ignore_na": kwargs["ignore_nulls"],
103+
}
77104
else: # sum, len, ...
78105
pandas_kwargs = {}
79106
return pandas_kwargs
@@ -182,13 +209,15 @@ def ewm_mean(
182209
) -> Self:
183210
return self._reuse_series(
184211
"ewm_mean",
185-
com=com,
186-
span=span,
187-
half_life=half_life,
188-
alpha=alpha,
189-
adjust=adjust,
190-
min_samples=min_samples,
191-
ignore_nulls=ignore_nulls,
212+
scalar_kwargs={
213+
"com": com,
214+
"span": span,
215+
"half_life": half_life,
216+
"alpha": alpha,
217+
"adjust": adjust,
218+
"min_samples": min_samples,
219+
"ignore_nulls": ignore_nulls,
220+
},
192221
)
193222

194223
def over( # noqa: C901, PLR0915
@@ -232,7 +261,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
232261
function_name, self._scalar_kwargs
233262
)
234263

235-
def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, PLR0912
264+
def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, PLR0912, PLR0914, PLR0915
236265
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
237266
if function_name == "cum_count":
238267
plx = self.__narwhals_namespace__()
@@ -268,6 +297,18 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901,
268297
)
269298
else:
270299
res_native = getattr(rolling, pandas_function_name)()
300+
elif function_name.startswith("ewm"):
301+
if self._implementation.is_pandas() and (
302+
backend_version := self._backend_version
303+
) < (1, 2): # pragma: no cover
304+
msg = (
305+
"Exponentially weighted calculation is not available in over "
306+
f"context for pandas versions older than 1.2.0, found {backend_version}."
307+
)
308+
raise NotImplementedError(msg)
309+
ewm = grouped[list(output_names)].ewm(**pandas_kwargs)
310+
assert pandas_function_name is not None # help mypy # noqa: S101
311+
res_native = getattr(ewm, pandas_function_name)()
271312
elif function_name == "fill_null":
272313
assert "strategy" in self._scalar_kwargs # noqa: S101
273314
assert "limit" in self._scalar_kwargs # noqa: S101

narwhals/_pandas_like/group_by.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@ class PandasLikeGroupBy(EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr", st
2929
"len": "size",
3030
"n_unique": "nunique",
3131
"count": "count",
32+
"quantile": "quantile",
3233
}
3334

3435
def __init__(

tests/expr_and_series/over_test.py

Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -432,3 +432,55 @@ def test_len_over_2369(constructor: Constructor, request: pytest.FixtureRequest)
432432
result = df.with_columns(a_len_per_group=nw.len().over("b")).sort("a")
433433
expected = {"a": [1, 2, 4], "b": ["x", "x", "y"], "a_len_per_group": [2, 2, 1]}
434434
assert_equal_data(result, expected)
435+
436+
437+
def test_over_quantile(constructor: Constructor, request: pytest.FixtureRequest) -> None:
438+
if "pyarrow_table" in str(constructor) or "pyspark" in str(constructor):
439+
request.applymarker(pytest.mark.xfail)
440+
441+
data = {"a": [1, 2, 3, 4, 5, 6], "b": ["x", "x", "x", "y", "y", "y"]}
442+
443+
quantile_expr = nw.col("a").quantile(quantile=0.5, interpolation="linear")
444+
native_frame = constructor(data)
445+
446+
if "dask" in str(constructor):
447+
native_frame = native_frame.repartition(npartitions=1) # type: ignore[union-attr]
448+
449+
result = (
450+
nw.from_native(native_frame)
451+
.with_columns(
452+
quantile_over_b=quantile_expr.over("b"), quantile_global=quantile_expr
453+
)
454+
.sort("a")
455+
)
456+
457+
expected = {
458+
**data,
459+
"quantile_over_b": [2, 2, 2, 5, 5, 5],
460+
"quantile_global": [3.5] * 6,
461+
}
462+
assert_equal_data(result, expected)
463+
464+
465+
def test_over_ewm_mean(
466+
constructor_eager: ConstructorEager, request: pytest.FixtureRequest
467+
) -> None:
468+
if "pyarrow_table" in str(constructor_eager) or "modin" in str(constructor_eager):
469+
request.applymarker(pytest.mark.xfail)
470+
if "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 2):
471+
request.applymarker(pytest.mark.xfail(reason="too old, not implemented"))
472+
473+
data = {"a": [0.0, 1.0, 3.0, 5.0, 7.0, 7.5], "b": [1, 1, 1, 2, 2, 2]}
474+
475+
ewm_expr = nw.col("a").ewm_mean(com=1)
476+
result = (
477+
nw.from_native(constructor_eager(data))
478+
.with_columns(ewm_over_b=ewm_expr.over("b"), ewm_global=ewm_expr)
479+
.sort("a")
480+
)
481+
expected = {
482+
**data,
483+
"ewm_over_b": [0.0, 2 / 3, 2.0, 5.0, 6 + 1 / 3, 7.0],
484+
"ewm_global": [0.0, 2 / 3, 2.0, 3.6, 5.354838709677419, 6.444444444444445],
485+
}
486+
assert_equal_data(result, expected)

0 commit comments

Comments
 (0)