Skip to content

Commit 730f4b7

Browse files
authored
feat: add rolling_sum for sqlframe and pyspark (#2168)
1 parent 944f6da commit 730f4b7

File tree

13 files changed

+206
-62
lines changed

13 files changed

+206
-62
lines changed

narwhals/_arrow/expr.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -545,11 +545,7 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
545545
return reuse_series_implementation(self, "cum_prod", reverse=reverse)
546546

547547
def rolling_sum(
548-
self: Self,
549-
window_size: int,
550-
*,
551-
min_samples: int | None,
552-
center: bool,
548+
self: Self, window_size: int, *, min_samples: int, center: bool
553549
) -> Self:
554550
return reuse_series_implementation(
555551
self,
@@ -563,7 +559,7 @@ def rolling_mean(
563559
self: Self,
564560
window_size: int,
565561
*,
566-
min_samples: int | None,
562+
min_samples: int,
567563
center: bool,
568564
) -> Self:
569565
return reuse_series_implementation(
@@ -578,7 +574,7 @@ def rolling_var(
578574
self: Self,
579575
window_size: int,
580576
*,
581-
min_samples: int | None,
577+
min_samples: int,
582578
center: bool,
583579
ddof: int,
584580
) -> Self:
@@ -595,7 +591,7 @@ def rolling_std(
595591
self: Self,
596592
window_size: int,
597593
*,
598-
min_samples: int | None,
594+
min_samples: int,
599595
center: bool,
600596
ddof: int,
601597
) -> Self:

narwhals/_arrow/series.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,7 @@ def rolling_sum(
898898
self: Self,
899899
window_size: int,
900900
*,
901-
min_samples: int | None,
901+
min_samples: int,
902902
center: bool,
903903
) -> Self:
904904
min_samples = min_samples if min_samples is not None else window_size
@@ -932,7 +932,7 @@ def rolling_mean(
932932
self: Self,
933933
window_size: int,
934934
*,
935-
min_samples: int | None,
935+
min_samples: int,
936936
center: bool,
937937
) -> Self:
938938
min_samples = min_samples if min_samples is not None else window_size
@@ -969,7 +969,7 @@ def rolling_var(
969969
self: Self,
970970
window_size: int,
971971
*,
972-
min_samples: int | None,
972+
min_samples: int,
973973
center: bool,
974974
ddof: int,
975975
) -> Self:
@@ -1022,7 +1022,7 @@ def rolling_std(
10221022
self: Self,
10231023
window_size: int,
10241024
*,
1025-
min_samples: int | None,
1025+
min_samples: int,
10261026
center: bool,
10271027
ddof: int,
10281028
) -> Self:

narwhals/_dask/expr.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,16 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
381381

382382
return self._from_call(lambda _input: _input.cumprod(), "cum_prod")
383383

384+
def rolling_sum(
385+
self: Self, window_size: int, *, min_samples: int, center: bool
386+
) -> Self:
387+
return self._from_call(
388+
lambda _input: _input.rolling(
389+
window=window_size, min_periods=min_samples, center=center
390+
).sum(),
391+
"rolling_sum",
392+
)
393+
384394
def sum(self: Self) -> Self:
385395
return self._from_call(lambda _input: _input.sum().to_series(), "sum")
386396

@@ -566,7 +576,7 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
566576
except KeyError:
567577
# window functions are unsupported: https://github.com/dask/dask/issues/11806
568578
msg = (
569-
f"Unsupported function: {function_name} in `over` context.\n\n."
579+
f"Unsupported function: {function_name} in `over` context.\n\n"
570580
f"Supported functions are {', '.join(AGGREGATIONS_TO_PANDAS_EQUIVALENT)}\n"
571581
)
572582
raise NotImplementedError(msg) from None
@@ -634,7 +644,6 @@ def name(self: Self) -> DaskExprNameNamespace:
634644
sample = not_implemented()
635645
map_batches = not_implemented()
636646
ewm_mean = not_implemented()
637-
rolling_sum = not_implemented()
638647
rolling_mean = not_implemented()
639648
rolling_var = not_implemented()
640649
rolling_std = not_implemented()

narwhals/_pandas_like/expr.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
# Pandas cumcount counts nulls while Polars does not
4444
# So, instead of using "cumcount" we use "cumsum" on notna() to get the same result
4545
"cum_count": "cumsum",
46+
"rolling_sum": "sum",
4647
"shift": "shift",
4748
"rank": "rank",
4849
"diff": "diff",
@@ -64,6 +65,12 @@ def window_kwargs_to_pandas_equivalent(
6465
}
6566
elif function_name.startswith("cum_"): # Cumulative operation
6667
pandas_kwargs = {"skipna": True}
68+
elif function_name.startswith("rolling_"): # Rolling operation
69+
pandas_kwargs = {
70+
"min_periods": kwargs["min_samples"],
71+
"window": kwargs["window_size"],
72+
"center": kwargs["center"],
73+
}
6774
else: # e.g. std, var
6875
pandas_kwargs = kwargs
6976
return pandas_kwargs
@@ -487,11 +494,11 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
487494
raise NotImplementedError(msg)
488495
else:
489496
function_name: str = re.sub(r"(\w+->)", "", self._function_name)
490-
if pandas_function_name := WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get(
491-
function_name, AGGREGATIONS_TO_PANDAS_EQUIVALENT.get(function_name, None)
492-
):
493-
pass
494-
else:
497+
pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get(
498+
function_name,
499+
AGGREGATIONS_TO_PANDAS_EQUIVALENT.get(function_name),
500+
)
501+
if pandas_function_name is None:
495502
msg = (
496503
f"Unsupported function: {function_name} in `over` context.\n\n"
497504
f"Supported functions are {', '.join(WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT)}\n"
@@ -504,7 +511,6 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
504511

505512
def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
506513
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
507-
508514
if function_name == "cum_count":
509515
plx = self.__narwhals_namespace__()
510516
df = df.with_columns(~plx.col(*output_names).is_null())
@@ -527,9 +533,16 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
527533
elif reverse:
528534
columns = list(set(partition_by).union(output_names))
529535
df = df[columns][::-1]
530-
res_native = df._native_frame.groupby(partition_by)[
531-
list(output_names)
532-
].transform(pandas_function_name, **pandas_kwargs)
536+
if function_name.startswith("rolling"):
537+
rolling = df._native_frame.groupby(partition_by)[
538+
list(output_names)
539+
].rolling(**pandas_kwargs)
540+
assert pandas_function_name is not None # help mypy # noqa: S101
541+
res_native = getattr(rolling, pandas_function_name)()
542+
else:
543+
res_native = df._native_frame.groupby(partition_by)[
544+
list(output_names)
545+
].transform(pandas_function_name, **pandas_kwargs)
533546
result_frame = df._from_native_frame(res_native).rename(
534547
dict(zip(output_names, aliases))
535548
)
@@ -650,25 +663,23 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
650663
)
651664

652665
def rolling_sum(
653-
self: Self,
654-
window_size: int,
655-
*,
656-
min_samples: int | None,
657-
center: bool,
666+
self: Self, window_size: int, *, min_samples: int, center: bool
658667
) -> Self:
659668
return reuse_series_implementation(
660669
self,
661670
"rolling_sum",
662-
window_size=window_size,
663-
min_samples=min_samples,
664-
center=center,
671+
call_kwargs={
672+
"window_size": window_size,
673+
"min_samples": min_samples,
674+
"center": center,
675+
},
665676
)
666677

667678
def rolling_mean(
668679
self: Self,
669680
window_size: int,
670681
*,
671-
min_samples: int | None,
682+
min_samples: int,
672683
center: bool,
673684
) -> Self:
674685
return reuse_series_implementation(
@@ -683,7 +694,7 @@ def rolling_var(
683694
self: Self,
684695
window_size: int,
685696
*,
686-
min_samples: int | None,
697+
min_samples: int,
687698
center: bool,
688699
ddof: int,
689700
) -> Self:
@@ -700,7 +711,7 @@ def rolling_std(
700711
self: Self,
701712
window_size: int,
702713
*,
703-
min_samples: int | None,
714+
min_samples: int,
704715
center: bool,
705716
ddof: int,
706717
) -> Self:

narwhals/_pandas_like/series.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -922,7 +922,7 @@ def rolling_sum(
922922
self: Self,
923923
window_size: int,
924924
*,
925-
min_samples: int | None,
925+
min_samples: int,
926926
center: bool,
927927
) -> Self:
928928
result = self._native_series.rolling(
@@ -934,7 +934,7 @@ def rolling_mean(
934934
self: Self,
935935
window_size: int,
936936
*,
937-
min_samples: int | None,
937+
min_samples: int,
938938
center: bool,
939939
) -> Self:
940940
result = self._native_series.rolling(
@@ -946,7 +946,7 @@ def rolling_var(
946946
self: Self,
947947
window_size: int,
948948
*,
949-
min_samples: int | None,
949+
min_samples: int,
950950
center: bool,
951951
ddof: int,
952952
) -> Self:
@@ -959,7 +959,7 @@ def rolling_std(
959959
self: Self,
960960
window_size: int,
961961
*,
962-
min_samples: int | None,
962+
min_samples: int,
963963
center: bool,
964964
ddof: int,
965965
) -> Self:

narwhals/_polars/expr.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def rolling_var(
118118
self: Self,
119119
window_size: int,
120120
*,
121-
min_samples: int | None,
121+
min_samples: int,
122122
center: bool,
123123
ddof: int,
124124
) -> Self:
@@ -144,7 +144,7 @@ def rolling_std(
144144
self: Self,
145145
window_size: int,
146146
*,
147-
min_samples: int | None,
147+
min_samples: int,
148148
center: bool,
149149
ddof: int,
150150
) -> Self:
@@ -167,11 +167,7 @@ def rolling_std(
167167
)
168168

169169
def rolling_sum(
170-
self: Self,
171-
window_size: int,
172-
*,
173-
min_samples: int | None,
174-
center: bool,
170+
self: Self, window_size: int, *, min_samples: int, center: bool
175171
) -> Self:
176172
extra_kwargs = (
177173
{"min_periods": min_samples}
@@ -191,7 +187,7 @@ def rolling_mean(
191187
self: Self,
192188
window_size: int,
193189
*,
194-
min_samples: int | None,
190+
min_samples: int,
195191
center: bool,
196192
) -> Self:
197193
extra_kwargs = (

narwhals/_polars/series.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def rolling_var(
309309
self: Self,
310310
window_size: int,
311311
*,
312-
min_samples: int | None,
312+
min_samples: int,
313313
center: bool,
314314
ddof: int,
315315
) -> Self:
@@ -336,7 +336,7 @@ def rolling_std(
336336
self: Self,
337337
window_size: int,
338338
*,
339-
min_samples: int | None,
339+
min_samples: int,
340340
center: bool,
341341
ddof: int,
342342
) -> Self:
@@ -363,7 +363,7 @@ def rolling_sum(
363363
self: Self,
364364
window_size: int,
365365
*,
366-
min_samples: int | None,
366+
min_samples: int,
367367
center: bool,
368368
) -> Self:
369369
extra_kwargs = (
@@ -384,7 +384,7 @@ def rolling_mean(
384384
self: Self,
385385
window_size: int,
386386
*,
387-
min_samples: int | None,
387+
min_samples: int,
388388
center: bool,
389389
) -> Self:
390390
extra_kwargs = (

narwhals/_spark_like/expr.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,12 +569,38 @@ def func(
569569
self._Window()
570570
.partitionBy(list(partition_by))
571571
.orderBy(order_by_cols)
572-
.rangeBetween(self._Window().unboundedPreceding, 0)
572+
.rowsBetween(self._Window().unboundedPreceding, 0)
573573
)
574574
return self._F.sum(_input).over(window)
575575

576576
return self._with_window_function(func)
577577

578+
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
579+
if center:
580+
half = (window_size - 1) // 2
581+
remainder = (window_size - 1) % 2
582+
start = self._Window().currentRow - half - remainder
583+
end = self._Window().currentRow + half
584+
else:
585+
start = self._Window().currentRow - window_size + 1
586+
end = self._Window().currentRow
587+
588+
def func(
589+
_input: Column, partition_by: Sequence[str], order_by: Sequence[str]
590+
) -> Column:
591+
window = (
592+
self._Window()
593+
.partitionBy(list(partition_by))
594+
.orderBy([self._F.col(x).asc_nulls_first() for x in order_by])
595+
.rowsBetween(start, end)
596+
)
597+
return self._F.when(
598+
self._F.count(_input).over(window) >= min_samples,
599+
self._F.sum(_input).over(window),
600+
)
601+
602+
return self._with_window_function(func)
603+
578604
@property
579605
def str(self: Self) -> SparkLikeExprStringNamespace:
580606
return SparkLikeExprStringNamespace(self)
@@ -602,7 +628,6 @@ def list(self: Self) -> SparkLikeExprListNamespace:
602628
sample = not_implemented()
603629
map_batches = not_implemented()
604630
ewm_mean = not_implemented()
605-
rolling_sum = not_implemented()
606631
rolling_mean = not_implemented()
607632
rolling_var = not_implemented()
608633
rolling_std = not_implemented()

0 commit comments

Comments
 (0)