|
11 | 11 | from typing import TypeVar |
12 | 12 |
|
13 | 13 | from narwhals.dependencies import is_numpy_array |
| 14 | +from narwhals.exceptions import InvalidOperationError |
14 | 15 | from narwhals.utils import flatten |
15 | 16 |
|
16 | 17 | if TYPE_CHECKING: |
@@ -2990,6 +2991,123 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self: |
2990 | 2991 | """ |
2991 | 2992 | return self.__class__(lambda plx: self._call(plx).cum_prod(reverse=reverse)) |
2992 | 2993 |
|
| 2994 | + def rolling_sum( |
| 2995 | + self: Self, |
| 2996 | + window_size: int, |
| 2997 | + *, |
| 2998 | + min_periods: int | None = None, |
| 2999 | + center: bool = False, |
| 3000 | + ) -> Self: |
| 3001 | + """Apply a rolling sum (moving sum) over the values. |
| 3002 | +
|
| 3003 | + !!! warning |
| 3004 | + This functionality is considered **unstable**. It may be changed at any point |
| 3005 | + without it being considered a breaking change. |
| 3006 | +
|
| 3007 | + A window of length `window_size` will traverse the values. The resulting values |
| 3008 | + will be aggregated to their sum. |
| 3009 | +
|
| 3010 | + The window at a given row will include the row itself and the `window_size - 1` |
| 3011 | + elements before it. |
| 3012 | +
|
| 3013 | + Arguments: |
| 3014 | + window_size: The length of the window in number of elements. It must be a |
| 3015 | + strictly positive integer. |
| 3016 | + min_periods: The number of values in the window that should be non-null before |
| 3017 | + computing a result. If set to `None` (default), it will be set equal to |
| 3018 | + `window_size`. If provided, it must be a strictly positive integer, and |
| 3019 | + less than or equal to `window_size` |
| 3020 | + center: Set the labels at the center of the window. |
| 3021 | +
|
| 3022 | + Returns: |
| 3023 | + A new expression. |
| 3024 | +
|
| 3025 | + Examples: |
| 3026 | + >>> import narwhals as nw |
| 3027 | + >>> import pandas as pd |
| 3028 | + >>> import polars as pl |
| 3029 | + >>> import pyarrow as pa |
| 3030 | + >>> data = {"a": [1.0, 2.0, None, 4.0]} |
| 3031 | + >>> df_pd = pd.DataFrame(data) |
| 3032 | + >>> df_pl = pl.DataFrame(data) |
| 3033 | + >>> df_pa = pa.table(data) |
| 3034 | +
|
| 3035 | + We define a library agnostic function: |
| 3036 | +
|
| 3037 | + >>> @nw.narwhalify |
| 3038 | + ... def func(df): |
| 3039 | + ... return df.with_columns( |
| 3040 | + ... b=nw.col("a").rolling_sum(window_size=3, min_periods=1) |
| 3041 | + ... ) |
| 3042 | +
|
| 3043 | + We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`: |
| 3044 | +
|
| 3045 | + >>> func(df_pd) |
| 3046 | + a b |
| 3047 | + 0 1.0 1.0 |
| 3048 | + 1 2.0 3.0 |
| 3049 | + 2 NaN 3.0 |
| 3050 | + 3 4.0 6.0 |
| 3051 | +
|
| 3052 | + >>> func(df_pl) |
| 3053 | + shape: (4, 2) |
| 3054 | + ┌──────┬─────┐ |
| 3055 | + │ a ┆ b │ |
| 3056 | + │ --- ┆ --- │ |
| 3057 | + │ f64 ┆ f64 │ |
| 3058 | + ╞══════╪═════╡ |
| 3059 | + │ 1.0 ┆ 1.0 │ |
| 3060 | + │ 2.0 ┆ 3.0 │ |
| 3061 | + │ null ┆ 3.0 │ |
| 3062 | + │ 4.0 ┆ 6.0 │ |
| 3063 | + └──────┴─────┘ |
| 3064 | +
|
| 3065 | + >>> func(df_pa) # doctest:+ELLIPSIS |
| 3066 | + pyarrow.Table |
| 3067 | + a: double |
| 3068 | + b: double |
| 3069 | + ---- |
| 3070 | + a: [[1,2,null,4]] |
| 3071 | + b: [[1,3,3,6]] |
| 3072 | + """ |
| 3073 | + if window_size < 1: |
| 3074 | + msg = "window_size must be greater or equal than 1" |
| 3075 | + raise ValueError(msg) |
| 3076 | + |
| 3077 | + if not isinstance(window_size, int): |
| 3078 | + _type = window_size.__class__.__name__ |
| 3079 | + msg = ( |
| 3080 | + f"argument 'window_size': '{_type}' object cannot be " |
| 3081 | + "interpreted as an integer" |
| 3082 | + ) |
| 3083 | + raise TypeError(msg) |
| 3084 | + |
| 3085 | + if min_periods is not None: |
| 3086 | + if min_periods < 1: |
| 3087 | + msg = "min_periods must be greater or equal than 1" |
| 3088 | + raise ValueError(msg) |
| 3089 | + |
| 3090 | + if not isinstance(min_periods, int): |
| 3091 | + _type = min_periods.__class__.__name__ |
| 3092 | + msg = ( |
| 3093 | + f"argument 'min_periods': '{_type}' object cannot be " |
| 3094 | + "interpreted as an integer" |
| 3095 | + ) |
| 3096 | + raise TypeError(msg) |
| 3097 | + if min_periods > window_size: |
| 3098 | + msg = "`min_periods` must be less or equal than `window_size`" |
| 3099 | + raise InvalidOperationError(msg) |
| 3100 | + else: |
| 3101 | + min_periods = window_size |
| 3102 | + |
| 3103 | + return self.__class__( |
| 3104 | + lambda plx: self._call(plx).rolling_sum( |
| 3105 | + window_size=window_size, |
| 3106 | + min_periods=min_periods, |
| 3107 | + center=center, |
| 3108 | + ) |
| 3109 | + ) |
| 3110 | + |
2993 | 3111 | @property |
2994 | 3112 | def str(self: Self) -> ExprStringNamespace[Self]: |
2995 | 3113 | return ExprStringNamespace(self) |
|
0 commit comments