Skip to content

Commit 4f1b172

Browse files
committed
Merge remote-tracking branch 'upstream/main' into from-numpy-2d-ns
2 parents 0e67a1e + 36c6d57 commit 4f1b172

File tree

5 files changed

+214
-41
lines changed

5 files changed

+214
-41
lines changed

narwhals/_compliant/expr.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,6 @@ class LazyExpr(
882882
sample: not_implemented = not_implemented()
883883
map_batches: not_implemented = not_implemented()
884884
ewm_mean: not_implemented = not_implemented()
885-
rolling_mean: not_implemented = not_implemented()
886885
rolling_var: not_implemented = not_implemented()
887886
rolling_std: not_implemented = not_implemented()
888887
gather_every: not_implemented = not_implemented()

narwhals/_dask/expr.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,16 @@ def rolling_sum(
392392
"rolling_sum",
393393
)
394394

395+
def rolling_mean(
396+
self: Self, window_size: int, *, min_samples: int, center: bool
397+
) -> Self:
398+
return self._from_call(
399+
lambda _input: _input.rolling(
400+
window=window_size, min_periods=min_samples, center=center
401+
).mean(),
402+
"rolling_mean",
403+
)
404+
395405
def sum(self: Self) -> Self:
396406
return self._from_call(lambda _input: _input.sum().to_series(), "sum")
397407

narwhals/_duckdb/expr.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,35 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:
9898

9999
return func
100100

101+
def _rolling_window_func(
102+
self,
103+
*,
104+
func_name: Literal["sum", "mean", "std", "var"],
105+
center: bool,
106+
window_size: int,
107+
min_samples: int,
108+
) -> WindowFunction:
109+
if center:
110+
half = (window_size - 1) // 2
111+
remainder = (window_size - 1) % 2
112+
start = f"{half + remainder} preceding"
113+
end = f"{half} following"
114+
else:
115+
start = f"{window_size - 1} preceding"
116+
end = "current row"
117+
118+
def func(window_inputs: WindowInputs) -> duckdb.Expression:
119+
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
120+
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
121+
window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})"
122+
sql = (
123+
f"case when count({window_inputs.expr}) over {window} >= {min_samples}"
124+
f"then {func_name}({window_inputs.expr}) over {window} else null end"
125+
)
126+
return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
127+
128+
return func
129+
101130
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
102131
if kind is ExprKind.LITERAL:
103132
return self
@@ -546,26 +575,24 @@ def cum_prod(self, *, reverse: bool) -> Self:
546575
)
547576

548577
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
549-
if center:
550-
half = (window_size - 1) // 2
551-
remainder = (window_size - 1) % 2
552-
start = f"{half + remainder} preceding"
553-
end = f"{half} following"
554-
else:
555-
start = f"{window_size - 1} preceding"
556-
end = "current row"
557-
558-
def func(window_inputs: WindowInputs) -> duckdb.Expression:
559-
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
560-
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
561-
window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})"
562-
sql = (
563-
f"case when count({window_inputs.expr}) over {window} >= {min_samples}"
564-
f"then sum({window_inputs.expr}) over {window} else null end"
578+
return self._with_window_function(
579+
self._rolling_window_func(
580+
func_name="sum",
581+
center=center,
582+
window_size=window_size,
583+
min_samples=min_samples,
565584
)
566-
return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
585+
)
567586

568-
return self._with_window_function(func)
587+
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
588+
return self._with_window_function(
589+
self._rolling_window_func(
590+
func_name="mean",
591+
center=center,
592+
window_size=window_size,
593+
min_samples=min_samples,
594+
)
595+
)
569596

570597
def fill_null(
571598
self: Self, value: Self | Any, strategy: Any, limit: int | None

narwhals/_spark_like/expr.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,39 @@ def func(window_inputs: WindowInputs) -> Column:
159159

160160
return func
161161

162+
def _rolling_window_func(
163+
self,
164+
*,
165+
func_name: Literal["sum", "mean", "std", "var"],
166+
center: bool,
167+
window_size: int,
168+
min_samples: int,
169+
) -> WindowFunction:
170+
if center:
171+
half = (window_size - 1) // 2
172+
remainder = (window_size - 1) % 2
173+
start = self._Window().currentRow - half - remainder
174+
end = self._Window().currentRow + half
175+
else:
176+
start = self._Window().currentRow - window_size + 1
177+
end = self._Window().currentRow
178+
179+
def func(window_inputs: WindowInputs) -> Column:
180+
window = (
181+
self._Window()
182+
.partitionBy(list(window_inputs.partition_by))
183+
.orderBy(
184+
[self._F.col(x).asc_nulls_first() for x in window_inputs.order_by]
185+
)
186+
.rowsBetween(start, end)
187+
)
188+
return self._F.when(
189+
self._F.count(window_inputs.expr).over(window) >= min_samples,
190+
getattr(self._F, func_name)(window_inputs.expr).over(window),
191+
)
192+
193+
return func
194+
162195
@classmethod
163196
def from_column_names(
164197
cls: type[Self],
@@ -623,30 +656,24 @@ def _fill_null(_input: Column, value: Column) -> Column:
623656
return self._from_call(_fill_null, value=value)
624657

625658
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
626-
if center:
627-
half = (window_size - 1) // 2
628-
remainder = (window_size - 1) % 2
629-
start = self._Window().currentRow - half - remainder
630-
end = self._Window().currentRow + half
631-
else:
632-
start = self._Window().currentRow - window_size + 1
633-
end = self._Window().currentRow
634-
635-
def func(window_inputs: WindowInputs) -> Column:
636-
window = (
637-
self._Window()
638-
.partitionBy(list(window_inputs.partition_by))
639-
.orderBy(
640-
[self._F.col(x).asc_nulls_first() for x in window_inputs.order_by]
641-
)
642-
.rowsBetween(start, end)
643-
)
644-
return self._F.when(
645-
self._F.count(window_inputs.expr).over(window) >= min_samples,
646-
self._F.sum(window_inputs.expr).over(window),
659+
return self._with_window_function(
660+
self._rolling_window_func(
661+
func_name="sum",
662+
center=center,
663+
window_size=window_size,
664+
min_samples=min_samples,
647665
)
666+
)
648667

649-
return self._with_window_function(func)
668+
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
669+
return self._with_window_function(
670+
self._rolling_window_func(
671+
func_name="mean",
672+
center=center,
673+
window_size=window_size,
674+
min_samples=min_samples,
675+
)
676+
)
650677

651678
@property
652679
def str(self: Self) -> SparkLikeExprStringNamespace:

tests/expr_and_series/rolling_mean_test.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from hypothesis import given
1111

1212
import narwhals.stable.v1 as nw
13+
from tests.utils import DUCKDB_VERSION
1314
from tests.utils import PANDAS_VERSION
15+
from tests.utils import POLARS_VERSION
16+
from tests.utils import Constructor
1417
from tests.utils import ConstructorEager
1518
from tests.utils import assert_equal_data
1619

@@ -95,3 +98,110 @@ def test_rolling_mean_hypothesis(center: bool, values: list[float]) -> None: #
9598
)
9699
expected_dict = nw.from_native(expected, eager_only=True).to_dict(as_series=False)
97100
assert_equal_data(result, expected_dict)
101+
102+
103+
@pytest.mark.filterwarnings(
104+
"ignore:`Expr.rolling_mean` is being called from the stable API although considered an unstable feature."
105+
)
106+
@pytest.mark.parametrize(
107+
("expected_a", "window_size", "min_samples", "center"),
108+
[
109+
([None, None, 1.5, None, None, 5, 8.5], 2, None, False),
110+
([None, None, 1.5, None, None, 5, 8.5], 2, 2, False),
111+
([None, None, 1.5, 1.5, None, 5, 7.0], 3, 2, False),
112+
([1, None, 1.5, 1.5, 4, 5, 7], 3, 1, False),
113+
([1.5, 1, 1.5, 2, 5, 7, 8.5], 3, 1, True),
114+
([1.5, 1, 1.5, 1.5, 5, 7, 7], 4, 1, True),
115+
([1.5, 1.5, 1.5, 1.5, 7, 7, 7], 5, 1, True),
116+
],
117+
)
118+
def test_rolling_mean_expr_lazy_grouped(
119+
constructor: Constructor,
120+
expected_a: list[float],
121+
window_size: int,
122+
min_samples: int,
123+
request: pytest.FixtureRequest,
124+
*,
125+
center: bool,
126+
) -> None:
127+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 10)) or (
128+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
129+
):
130+
pytest.skip()
131+
if "pandas" in str(constructor):
132+
pytest.skip()
133+
if any(x in str(constructor) for x in ("dask", "pyarrow_table")):
134+
request.applymarker(pytest.mark.xfail)
135+
if "cudf" in str(constructor) and center:
136+
# center is not implemented for offset-based windows
137+
request.applymarker(pytest.mark.xfail)
138+
if "modin" in str(constructor):
139+
# unreliable
140+
pytest.skip()
141+
data = {
142+
"a": [1, None, 2, None, 4, 6, 11],
143+
"g": [1, 1, 1, 1, 2, 2, 2],
144+
"b": [1, None, 2, 3, 4, 5, 6],
145+
"i": list(range(7)),
146+
}
147+
df = nw.from_native(constructor(data))
148+
result = (
149+
df.with_columns(
150+
nw.col("a")
151+
.rolling_mean(window_size, min_samples=min_samples, center=center)
152+
.over("g", order_by="b")
153+
)
154+
.sort("i")
155+
.select("a")
156+
)
157+
expected = {"a": expected_a}
158+
assert_equal_data(result, expected)
159+
160+
161+
@pytest.mark.filterwarnings(
162+
"ignore:`Expr.rolling_mean` is being called from the stable API although considered an unstable feature."
163+
)
164+
@pytest.mark.parametrize(
165+
("expected_a", "window_size", "min_samples", "center"),
166+
[
167+
([None, None, 1.5, None, None, 5, 8.5], 2, None, False),
168+
([None, None, 1.5, None, None, 5, 8.5], 2, 2, False),
169+
([None, None, 1.5, 1.5, 3, 5, 7], 3, 2, False),
170+
([1, None, 1.5, 1.5, 3, 5, 7], 3, 1, False),
171+
([1.5, 1, 1.5, 3, 5, 7, 8.5], 3, 1, True),
172+
([1.5, 1, 1.5, 2.3333333333333335, 4, 7, 7], 4, 1, True),
173+
([1.5, 1.5, 2.3333333333333335, 3.25, 5.75, 7.0, 7.0], 5, 1, True),
174+
],
175+
)
176+
def test_rolling_mean_expr_lazy_ungrouped(
177+
constructor: Constructor,
178+
expected_a: list[float],
179+
window_size: int,
180+
min_samples: int,
181+
*,
182+
center: bool,
183+
) -> None:
184+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 10)) or (
185+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
186+
):
187+
pytest.skip()
188+
if "modin" in str(constructor):
189+
# unreliable
190+
pytest.skip()
191+
data = {
192+
"a": [1, None, 2, None, 4, 6, 11],
193+
"b": [1, None, 2, 3, 4, 5, 6],
194+
"i": list(range(7)),
195+
}
196+
df = nw.from_native(constructor(data))
197+
result = (
198+
df.with_columns(
199+
nw.col("a")
200+
.rolling_mean(window_size, min_samples=min_samples, center=center)
201+
.over(order_by="b")
202+
)
203+
.select("a", "i")
204+
.sort("i")
205+
)
206+
expected = {"a": expected_a, "i": list(range(7))}
207+
assert_equal_data(result, expected)

0 commit comments

Comments
 (0)