Skip to content

Commit 0d57240

Browse files
committed
Merge remote-tracking branch 'upstream/main' into compliant-package
2 parents 067252b + c223138 commit 0d57240

File tree

18 files changed, with 457 additions and 1226 deletions.

narwhals/_arrow/series.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,7 @@ def rolling_sum(
897897
self: Self,
898898
window_size: int,
899899
*,
900-
min_samples: int | None,
900+
min_samples: int,
901901
center: bool,
902902
) -> Self:
903903
min_samples = min_samples if min_samples is not None else window_size
@@ -931,7 +931,7 @@ def rolling_mean(
931931
self: Self,
932932
window_size: int,
933933
*,
934-
min_samples: int | None,
934+
min_samples: int,
935935
center: bool,
936936
) -> Self:
937937
min_samples = min_samples if min_samples is not None else window_size
@@ -968,7 +968,7 @@ def rolling_var(
968968
self: Self,
969969
window_size: int,
970970
*,
971-
min_samples: int | None,
971+
min_samples: int,
972972
center: bool,
973973
ddof: int,
974974
) -> Self:
@@ -1021,7 +1021,7 @@ def rolling_std(
10211021
self: Self,
10221022
window_size: int,
10231023
*,
1024-
min_samples: int | None,
1024+
min_samples: int,
10251025
center: bool,
10261026
ddof: int,
10271027
) -> Self:

narwhals/_compliant/expr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def rolling_sum(
195195
self,
196196
window_size: int,
197197
*,
198-
min_samples: int | None,
198+
min_samples: int,
199199
center: bool,
200200
) -> Self: ...
201201

@@ -204,7 +204,7 @@ def rolling_mean(
204204
self,
205205
window_size: int,
206206
*,
207-
min_samples: int | None,
207+
min_samples: int,
208208
center: bool,
209209
) -> Self: ...
210210

@@ -213,7 +213,7 @@ def rolling_var(
213213
self,
214214
window_size: int,
215215
*,
216-
min_samples: int | None,
216+
min_samples: int,
217217
center: bool,
218218
ddof: int,
219219
) -> Self: ...
@@ -223,7 +223,7 @@ def rolling_std(
223223
self,
224224
window_size: int,
225225
*,
226-
min_samples: int | None,
226+
min_samples: int,
227227
center: bool,
228228
ddof: int,
229229
) -> Self: ...

narwhals/_dask/expr.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,16 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
381381

382382
return self._from_call(lambda _input: _input.cumprod(), "cum_prod")
383383

384+
def rolling_sum(
385+
self: Self, window_size: int, *, min_samples: int, center: bool
386+
) -> Self:
387+
return self._from_call(
388+
lambda _input: _input.rolling(
389+
window=window_size, min_periods=min_samples, center=center
390+
).sum(),
391+
"rolling_sum",
392+
)
393+
384394
def sum(self: Self) -> Self:
385395
return self._from_call(lambda _input: _input.sum().to_series(), "sum")
386396

@@ -566,7 +576,7 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
566576
except KeyError:
567577
# window functions are unsupported: https://github.com/dask/dask/issues/11806
568578
msg = (
569-
f"Unsupported function: {function_name} in `over` context.\n\n."
579+
f"Unsupported function: {function_name} in `over` context.\n\n"
570580
f"Supported functions are {', '.join(AGGREGATIONS_TO_PANDAS_EQUIVALENT)}\n"
571581
)
572582
raise NotImplementedError(msg) from None

narwhals/_pandas_like/expr.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
# Pandas cumcount counts nulls while Polars does not
3838
# So, instead of using "cumcount" we use "cumsum" on notna() to get the same result
3939
"cum_count": "cumsum",
40+
"rolling_sum": "sum",
4041
"shift": "shift",
4142
"rank": "rank",
4243
"diff": "diff",
@@ -58,6 +59,12 @@ def window_kwargs_to_pandas_equivalent(
5859
}
5960
elif function_name.startswith("cum_"): # Cumulative operation
6061
pandas_kwargs = {"skipna": True}
62+
elif function_name.startswith("rolling_"): # Rolling operation
63+
pandas_kwargs = {
64+
"min_periods": kwargs["min_samples"],
65+
"window": kwargs["window_size"],
66+
"center": kwargs["center"],
67+
}
6168
else: # e.g. std, var
6269
pandas_kwargs = kwargs
6370
return pandas_kwargs
@@ -220,11 +227,11 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
220227
raise NotImplementedError(msg)
221228
else:
222229
function_name: str = re.sub(r"(\w+->)", "", self._function_name)
223-
if pandas_function_name := WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get(
224-
function_name, AGGREGATIONS_TO_PANDAS_EQUIVALENT.get(function_name, None)
225-
):
226-
pass
227-
else:
230+
pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get(
231+
function_name,
232+
AGGREGATIONS_TO_PANDAS_EQUIVALENT.get(function_name),
233+
)
234+
if pandas_function_name is None:
228235
msg = (
229236
f"Unsupported function: {function_name} in `over` context.\n\n"
230237
f"Supported functions are {', '.join(WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT)}\n"
@@ -237,7 +244,6 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
237244

238245
def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
239246
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
240-
241247
if function_name == "cum_count":
242248
plx = self.__narwhals_namespace__()
243249
df = df.with_columns(~plx.col(*output_names).is_null())
@@ -260,9 +266,16 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
260266
elif reverse:
261267
columns = list(set(partition_by).union(output_names))
262268
df = df[columns][::-1]
263-
res_native = df._native_frame.groupby(partition_by)[
264-
list(output_names)
265-
].transform(pandas_function_name, **pandas_kwargs)
269+
if function_name.startswith("rolling"):
270+
rolling = df._native_frame.groupby(partition_by)[
271+
list(output_names)
272+
].rolling(**pandas_kwargs)
273+
assert pandas_function_name is not None # help mypy # noqa: S101
274+
res_native = getattr(rolling, pandas_function_name)()
275+
else:
276+
res_native = df._native_frame.groupby(partition_by)[
277+
list(output_names)
278+
].transform(pandas_function_name, **pandas_kwargs)
266279
result_frame = df._from_native_frame(res_native).rename(
267280
dict(zip(output_names, aliases))
268281
)
@@ -331,6 +344,18 @@ def cum_max(self: Self, *, reverse: bool) -> Self:
331344
def cum_prod(self: Self, *, reverse: bool) -> Self:
332345
return self._reuse_series("cum_prod", call_kwargs={"reverse": reverse})
333346

347+
def rolling_sum(
348+
self: Self, window_size: int, *, min_samples: int, center: bool
349+
) -> Self:
350+
return self._reuse_series(
351+
"rolling_sum",
352+
call_kwargs={
353+
"window_size": window_size,
354+
"min_samples": min_samples,
355+
"center": center,
356+
},
357+
)
358+
334359
def rank(
335360
self: Self,
336361
method: Literal["average", "min", "max", "dense", "ordinal"],

narwhals/_pandas_like/series.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,7 @@ def rolling_sum(
929929
self: Self,
930930
window_size: int,
931931
*,
932-
min_samples: int | None,
932+
min_samples: int,
933933
center: bool,
934934
) -> Self:
935935
result = self._native_series.rolling(
@@ -941,7 +941,7 @@ def rolling_mean(
941941
self: Self,
942942
window_size: int,
943943
*,
944-
min_samples: int | None,
944+
min_samples: int,
945945
center: bool,
946946
) -> Self:
947947
result = self._native_series.rolling(
@@ -953,7 +953,7 @@ def rolling_var(
953953
self: Self,
954954
window_size: int,
955955
*,
956-
min_samples: int | None,
956+
min_samples: int,
957957
center: bool,
958958
ddof: int,
959959
) -> Self:
@@ -966,7 +966,7 @@ def rolling_std(
966966
self: Self,
967967
window_size: int,
968968
*,
969-
min_samples: int | None,
969+
min_samples: int,
970970
center: bool,
971971
ddof: int,
972972
) -> Self:

narwhals/_polars/expr.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def rolling_var(
126126
self: Self,
127127
window_size: int,
128128
*,
129-
min_samples: int | None,
129+
min_samples: int,
130130
center: bool,
131131
ddof: int,
132132
) -> Self:
@@ -152,7 +152,7 @@ def rolling_std(
152152
self: Self,
153153
window_size: int,
154154
*,
155-
min_samples: int | None,
155+
min_samples: int,
156156
center: bool,
157157
ddof: int,
158158
) -> Self:
@@ -175,11 +175,7 @@ def rolling_std(
175175
)
176176

177177
def rolling_sum(
178-
self: Self,
179-
window_size: int,
180-
*,
181-
min_samples: int | None,
182-
center: bool,
178+
self: Self, window_size: int, *, min_samples: int, center: bool
183179
) -> Self:
184180
extra_kwargs = (
185181
{"min_periods": min_samples}
@@ -199,7 +195,7 @@ def rolling_mean(
199195
self: Self,
200196
window_size: int,
201197
*,
202-
min_samples: int | None,
198+
min_samples: int,
203199
center: bool,
204200
) -> Self:
205201
extra_kwargs = (

narwhals/_polars/series.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def rolling_var(
321321
self: Self,
322322
window_size: int,
323323
*,
324-
min_samples: int | None,
324+
min_samples: int,
325325
center: bool,
326326
ddof: int,
327327
) -> Self:
@@ -348,7 +348,7 @@ def rolling_std(
348348
self: Self,
349349
window_size: int,
350350
*,
351-
min_samples: int | None,
351+
min_samples: int,
352352
center: bool,
353353
ddof: int,
354354
) -> Self:
@@ -375,7 +375,7 @@ def rolling_sum(
375375
self: Self,
376376
window_size: int,
377377
*,
378-
min_samples: int | None,
378+
min_samples: int,
379379
center: bool,
380380
) -> Self:
381381
extra_kwargs = (
@@ -396,7 +396,7 @@ def rolling_mean(
396396
self: Self,
397397
window_size: int,
398398
*,
399-
min_samples: int | None,
399+
min_samples: int,
400400
center: bool,
401401
) -> Self:
402402
extra_kwargs = (

narwhals/_spark_like/expr.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,12 +569,38 @@ def func(
569569
self._Window()
570570
.partitionBy(list(partition_by))
571571
.orderBy(order_by_cols)
572-
.rangeBetween(self._Window().unboundedPreceding, 0)
572+
.rowsBetween(self._Window().unboundedPreceding, 0)
573573
)
574574
return self._F.sum(_input).over(window)
575575

576576
return self._with_window_function(func)
577577

578+
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
579+
if center:
580+
half = (window_size - 1) // 2
581+
remainder = (window_size - 1) % 2
582+
start = self._Window().currentRow - half - remainder
583+
end = self._Window().currentRow + half
584+
else:
585+
start = self._Window().currentRow - window_size + 1
586+
end = self._Window().currentRow
587+
588+
def func(
589+
_input: Column, partition_by: Sequence[str], order_by: Sequence[str]
590+
) -> Column:
591+
window = (
592+
self._Window()
593+
.partitionBy(list(partition_by))
594+
.orderBy([self._F.col(x).asc_nulls_first() for x in order_by])
595+
.rowsBetween(start, end)
596+
)
597+
return self._F.when(
598+
self._F.count(_input).over(window) >= min_samples,
599+
self._F.sum(_input).over(window),
600+
)
601+
602+
return self._with_window_function(func)
603+
578604
@property
579605
def str(self: Self) -> SparkLikeExprStringNamespace:
580606
return SparkLikeExprStringNamespace(self)

narwhals/expr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2191,14 +2191,14 @@ def rolling_sum(
21912191
|3 4.0 6.0|
21922192
└─────────────────────┘
21932193
"""
2194-
window_size, min_samples = _validate_rolling_arguments(
2194+
window_size, min_samples_int = _validate_rolling_arguments(
21952195
window_size=window_size, min_samples=min_samples
21962196
)
21972197

21982198
return self.__class__(
21992199
lambda plx: self._to_compliant_expr(plx).rolling_sum(
22002200
window_size=window_size,
2201-
min_samples=min_samples,
2201+
min_samples=min_samples_int,
22022202
center=center,
22032203
),
22042204
self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW),

0 commit comments

Comments (0)