Skip to content

Commit bbf2aa3

Browse files
authored
feat: add Series|Expr.rolling_sum method (#1395)
1 parent 836f086 commit bbf2aa3

File tree

13 files changed

+766
-1
lines changed

13 files changed

+766
-1
lines changed

docs/api-reference/exceptions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77
- ColumnNotFoundError
88
- InvalidIntoExprError
99
- InvalidOperationError
10+
- NarwhalsUnstableWarning
1011
show_source: false
1112
show_bases: false

docs/api-reference/expr.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
- pipe
4545
- quantile
4646
- replace_strict
47+
- rolling_sum
4748
- round
4849
- sample
4950
- shift

docs/api-reference/series.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
- quantile
5252
- rename
5353
- replace_strict
54+
- rolling_sum
5455
- round
5556
- sample
5657
- scatter

narwhals/_arrow/expr.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,21 @@ def cum_max(self: Self, *, reverse: bool) -> Self:
453453
def cum_prod(self: Self, *, reverse: bool) -> Self:
454454
return reuse_series_implementation(self, "cum_prod", reverse=reverse)
455455

456+
def rolling_sum(
457+
self: Self,
458+
window_size: int,
459+
*,
460+
min_periods: int | None,
461+
center: bool,
462+
) -> Self:
463+
return reuse_series_implementation(
464+
self,
465+
"rolling_sum",
466+
window_size=window_size,
467+
min_periods=min_periods,
468+
center=center,
469+
)
470+
456471
@property
457472
def dt(self: Self) -> ArrowExprDateTimeNamespace:
458473
return ArrowExprDateTimeNamespace(self)

narwhals/_arrow/series.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,54 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
874874
)
875875
return self._from_native_series(result)
876876

877+
def rolling_sum(
878+
self: Self,
879+
window_size: int,
880+
*,
881+
min_periods: int | None,
882+
center: bool,
883+
) -> Self:
884+
import pyarrow as pa # ignore-banned-import
885+
import pyarrow.compute as pc # ignore-banned-import
886+
887+
min_periods = min_periods if min_periods is not None else window_size
888+
if center:
889+
offset_left = window_size // 2
890+
offset_right = offset_left - (
891+
window_size % 2 == 0
892+
) # subtract one if window_size is even
893+
894+
native_series = self._native_series
895+
896+
pad_left = pa.array([None] * offset_left, type=native_series.type)
897+
pad_right = pa.array([None] * offset_right, type=native_series.type)
898+
padded_arr = self._from_native_series(
899+
pa.concat_arrays([pad_left, native_series.combine_chunks(), pad_right])
900+
)
901+
else:
902+
padded_arr = self
903+
904+
cum_sum = padded_arr.cum_sum(reverse=False).fill_null(strategy="forward")
905+
rolling_sum = (
906+
cum_sum - cum_sum.shift(window_size).fill_null(0)
907+
if window_size != 0
908+
else cum_sum
909+
)
910+
911+
valid_count = padded_arr.cum_count(reverse=False)
912+
count_in_window = valid_count - valid_count.shift(window_size).fill_null(0)
913+
914+
result = self._from_native_series(
915+
pc.if_else(
916+
(count_in_window >= min_periods)._native_series,
917+
rolling_sum._native_series,
918+
None,
919+
)
920+
)
921+
if center:
922+
result = result[offset_left + offset_right :]
923+
return result
924+
877925
def __iter__(self: Self) -> Iterator[Any]:
878926
yield from self._native_series.__iter__()
879927

narwhals/_dask/expr.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,32 @@ def is_finite(self: Self) -> Self:
840840
returns_scalar=False,
841841
)
842842

843+
def rolling_sum(
844+
self: Self,
845+
window_size: int,
846+
*,
847+
min_periods: int | None,
848+
center: bool,
849+
) -> Self:
850+
def func(
851+
_input: dask_expr.Series,
852+
_window: int,
853+
_min_periods: int | None,
854+
_center: bool, # noqa: FBT001
855+
) -> dask_expr.Series:
856+
return _input.rolling(
857+
window=_window, min_periods=_min_periods, center=_center
858+
).sum()
859+
860+
return self._from_call(
861+
func,
862+
"rolling_sum",
863+
window_size,
864+
min_periods,
865+
center,
866+
returns_scalar=False,
867+
)
868+
843869

844870
class DaskExprStringNamespace:
845871
def __init__(self, expr: DaskExpr) -> None:

narwhals/_pandas_like/expr.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,21 @@ def cum_max(self: Self, *, reverse: bool) -> Self:
464464
def cum_prod(self: Self, *, reverse: bool) -> Self:
465465
return reuse_series_implementation(self, "cum_prod", reverse=reverse)
466466

467+
def rolling_sum(
468+
self: Self,
469+
window_size: int,
470+
*,
471+
min_periods: int | None,
472+
center: bool,
473+
) -> Self:
474+
return reuse_series_implementation(
475+
self,
476+
"rolling_sum",
477+
window_size=window_size,
478+
min_periods=min_periods,
479+
center=center,
480+
)
481+
467482
@property
468483
def str(self: Self) -> PandasLikeExprStringNamespace:
469484
return PandasLikeExprStringNamespace(self)

narwhals/_pandas_like/series.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def fill_null(
457457
value: Any | None = None,
458458
strategy: Literal["forward", "backward"] | None = None,
459459
limit: int | None = None,
460-
) -> PandasLikeSeries:
460+
) -> Self:
461461
ser = self._native_series
462462
if value is not None:
463463
res_ser = self._from_native_series(ser.fillna(value=value))
@@ -798,6 +798,18 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
798798
)
799799
return self._from_native_series(result)
800800

801+
def rolling_sum(
802+
self: Self,
803+
window_size: int,
804+
*,
805+
min_periods: int | None,
806+
center: bool,
807+
) -> Self:
808+
result = self._native_series.rolling(
809+
window=window_size, min_periods=min_periods, center=center
810+
).sum()
811+
return self._from_native_series(result)
812+
801813
def __iter__(self: Self) -> Iterator[Any]:
802814
yield from self._native_series.__iter__()
803815

narwhals/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,7 @@ def from_invalid_type(cls, invalid_type: type) -> InvalidIntoExprError:
6060
" column with literal value `0`."
6161
)
6262
return InvalidIntoExprError(message)
63+
64+
65+
class NarwhalsUnstableWarning(UserWarning):
66+
"""Warning issued when a method or function is considered unstable in the stable api."""

narwhals/expr.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import TypeVar
1212

1313
from narwhals.dependencies import is_numpy_array
14+
from narwhals.exceptions import InvalidOperationError
1415
from narwhals.utils import flatten
1516

1617
if TYPE_CHECKING:
@@ -2990,6 +2991,123 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self:
29902991
"""
29912992
return self.__class__(lambda plx: self._call(plx).cum_prod(reverse=reverse))
29922993

2994+
def rolling_sum(
2995+
self: Self,
2996+
window_size: int,
2997+
*,
2998+
min_periods: int | None = None,
2999+
center: bool = False,
3000+
) -> Self:
3001+
"""Apply a rolling sum (moving sum) over the values.
3002+
3003+
!!! warning
3004+
This functionality is considered **unstable**. It may be changed at any point
3005+
without it being considered a breaking change.
3006+
3007+
A window of length `window_size` will traverse the values. The resulting values
3008+
will be aggregated to their sum.
3009+
3010+
The window at a given row will include the row itself and the `window_size - 1`
3011+
elements before it.
3012+
3013+
Arguments:
3014+
window_size: The length of the window in number of elements. It must be a
3015+
strictly positive integer.
3016+
min_periods: The number of values in the window that should be non-null before
3017+
computing a result. If set to `None` (default), it will be set equal to
3018+
`window_size`. If provided, it must be a strictly positive integer, and
3019+
less than or equal to `window_size`
3020+
center: Set the labels at the center of the window.
3021+
3022+
Returns:
3023+
A new expression.
3024+
3025+
Examples:
3026+
>>> import narwhals as nw
3027+
>>> import pandas as pd
3028+
>>> import polars as pl
3029+
>>> import pyarrow as pa
3030+
>>> data = {"a": [1.0, 2.0, None, 4.0]}
3031+
>>> df_pd = pd.DataFrame(data)
3032+
>>> df_pl = pl.DataFrame(data)
3033+
>>> df_pa = pa.table(data)
3034+
3035+
We define a library agnostic function:
3036+
3037+
>>> @nw.narwhalify
3038+
... def func(df):
3039+
... return df.with_columns(
3040+
... b=nw.col("a").rolling_sum(window_size=3, min_periods=1)
3041+
... )
3042+
3043+
We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`:
3044+
3045+
>>> func(df_pd)
3046+
a b
3047+
0 1.0 1.0
3048+
1 2.0 3.0
3049+
2 NaN 3.0
3050+
3 4.0 6.0
3051+
3052+
>>> func(df_pl)
3053+
shape: (4, 2)
3054+
┌──────┬─────┐
3055+
│ a ┆ b │
3056+
│ --- ┆ --- │
3057+
│ f64 ┆ f64 │
3058+
╞══════╪═════╡
3059+
│ 1.0 ┆ 1.0 │
3060+
│ 2.0 ┆ 3.0 │
3061+
│ null ┆ 3.0 │
3062+
│ 4.0 ┆ 6.0 │
3063+
└──────┴─────┘
3064+
3065+
>>> func(df_pa) # doctest:+ELLIPSIS
3066+
pyarrow.Table
3067+
a: double
3068+
b: double
3069+
----
3070+
a: [[1,2,null,4]]
3071+
b: [[1,3,3,6]]
3072+
"""
3073+
if window_size < 1:
3074+
msg = "window_size must be greater or equal than 1"
3075+
raise ValueError(msg)
3076+
3077+
if not isinstance(window_size, int):
3078+
_type = window_size.__class__.__name__
3079+
msg = (
3080+
f"argument 'window_size': '{_type}' object cannot be "
3081+
"interpreted as an integer"
3082+
)
3083+
raise TypeError(msg)
3084+
3085+
if min_periods is not None:
3086+
if min_periods < 1:
3087+
msg = "min_periods must be greater or equal than 1"
3088+
raise ValueError(msg)
3089+
3090+
if not isinstance(min_periods, int):
3091+
_type = min_periods.__class__.__name__
3092+
msg = (
3093+
f"argument 'min_periods': '{_type}' object cannot be "
3094+
"interpreted as an integer"
3095+
)
3096+
raise TypeError(msg)
3097+
if min_periods > window_size:
3098+
msg = "`min_periods` must be less or equal than `window_size`"
3099+
raise InvalidOperationError(msg)
3100+
else:
3101+
min_periods = window_size
3102+
3103+
return self.__class__(
3104+
lambda plx: self._call(plx).rolling_sum(
3105+
window_size=window_size,
3106+
min_periods=min_periods,
3107+
center=center,
3108+
)
3109+
)
3110+
29933111
@property
29943112
def str(self: Self) -> ExprStringNamespace[Self]:
29953113
return ExprStringNamespace(self)

0 commit comments

Comments
 (0)