| 
10 | 10 | from hypothesis import given  | 
11 | 11 | 
 
  | 
12 | 12 | import narwhals.stable.v1 as nw  | 
 | 13 | +from tests.utils import DUCKDB_VERSION  | 
13 | 14 | from tests.utils import PANDAS_VERSION  | 
 | 15 | +from tests.utils import POLARS_VERSION  | 
 | 16 | +from tests.utils import Constructor  | 
14 | 17 | from tests.utils import ConstructorEager  | 
15 | 18 | from tests.utils import assert_equal_data  | 
16 | 19 | 
 
  | 
@@ -95,3 +98,110 @@ def test_rolling_mean_hypothesis(center: bool, values: list[float]) -> None:  #  | 
95 | 98 |     )  | 
96 | 99 |     expected_dict = nw.from_native(expected, eager_only=True).to_dict(as_series=False)  | 
97 | 100 |     assert_equal_data(result, expected_dict)  | 
 | 101 | + | 
 | 102 | + | 
 | 103 | +@pytest.mark.filterwarnings(  | 
 | 104 | +    "ignore:`Expr.rolling_mean` is being called from the stable API although considered an unstable feature."  | 
 | 105 | +)  | 
 | 106 | +@pytest.mark.parametrize(  | 
 | 107 | +    ("expected_a", "window_size", "min_samples", "center"),  | 
 | 108 | +    [  | 
 | 109 | +        ([None, None, 1.5, None, None, 5, 8.5], 2, None, False),  | 
 | 110 | +        ([None, None, 1.5, None, None, 5, 8.5], 2, 2, False),  | 
 | 111 | +        ([None, None, 1.5, 1.5, None, 5, 7.0], 3, 2, False),  | 
 | 112 | +        ([1, None, 1.5, 1.5, 4, 5, 7], 3, 1, False),  | 
 | 113 | +        ([1.5, 1, 1.5, 2, 5, 7, 8.5], 3, 1, True),  | 
 | 114 | +        ([1.5, 1, 1.5, 1.5, 5, 7, 7], 4, 1, True),  | 
 | 115 | +        ([1.5, 1.5, 1.5, 1.5, 7, 7, 7], 5, 1, True),  | 
 | 116 | +    ],  | 
 | 117 | +)  | 
 | 118 | +def test_rolling_mean_expr_lazy_grouped(  | 
 | 119 | +    constructor: Constructor,  | 
 | 120 | +    expected_a: list[float],  | 
 | 121 | +    window_size: int,  | 
 | 122 | +    min_samples: int,  | 
 | 123 | +    request: pytest.FixtureRequest,  | 
 | 124 | +    *,  | 
 | 125 | +    center: bool,  | 
 | 126 | +) -> None:  | 
 | 127 | +    if ("polars" in str(constructor) and POLARS_VERSION < (1, 10)) or (  | 
 | 128 | +        "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)  | 
 | 129 | +    ):  | 
 | 130 | +        pytest.skip()  | 
 | 131 | +    if "pandas" in str(constructor):  | 
 | 132 | +        pytest.skip()  | 
 | 133 | +    if any(x in str(constructor) for x in ("dask", "pyarrow_table")):  | 
 | 134 | +        request.applymarker(pytest.mark.xfail)  | 
 | 135 | +    if "cudf" in str(constructor) and center:  | 
 | 136 | +        # center is not implemented for offset-based windows  | 
 | 137 | +        request.applymarker(pytest.mark.xfail)  | 
 | 138 | +    if "modin" in str(constructor):  | 
 | 139 | +        # unreliable  | 
 | 140 | +        pytest.skip()  | 
 | 141 | +    data = {  | 
 | 142 | +        "a": [1, None, 2, None, 4, 6, 11],  | 
 | 143 | +        "g": [1, 1, 1, 1, 2, 2, 2],  | 
 | 144 | +        "b": [1, None, 2, 3, 4, 5, 6],  | 
 | 145 | +        "i": list(range(7)),  | 
 | 146 | +    }  | 
 | 147 | +    df = nw.from_native(constructor(data))  | 
 | 148 | +    result = (  | 
 | 149 | +        df.with_columns(  | 
 | 150 | +            nw.col("a")  | 
 | 151 | +            .rolling_mean(window_size, min_samples=min_samples, center=center)  | 
 | 152 | +            .over("g", order_by="b")  | 
 | 153 | +        )  | 
 | 154 | +        .sort("i")  | 
 | 155 | +        .select("a")  | 
 | 156 | +    )  | 
 | 157 | +    expected = {"a": expected_a}  | 
 | 158 | +    assert_equal_data(result, expected)  | 
 | 159 | + | 
 | 160 | + | 
 | 161 | +@pytest.mark.filterwarnings(  | 
 | 162 | +    "ignore:`Expr.rolling_mean` is being called from the stable API although considered an unstable feature."  | 
 | 163 | +)  | 
 | 164 | +@pytest.mark.parametrize(  | 
 | 165 | +    ("expected_a", "window_size", "min_samples", "center"),  | 
 | 166 | +    [  | 
 | 167 | +        ([None, None, 1.5, None, None, 5, 8.5], 2, None, False),  | 
 | 168 | +        ([None, None, 1.5, None, None, 5, 8.5], 2, 2, False),  | 
 | 169 | +        ([None, None, 1.5, 1.5, 3, 5, 7], 3, 2, False),  | 
 | 170 | +        ([1, None, 1.5, 1.5, 3, 5, 7], 3, 1, False),  | 
 | 171 | +        ([1.5, 1, 1.5, 3, 5, 7, 8.5], 3, 1, True),  | 
 | 172 | +        ([1.5, 1, 1.5, 2.3333333333333335, 4, 7, 7], 4, 1, True),  | 
 | 173 | +        ([1.5, 1.5, 2.3333333333333335, 3.25, 5.75, 7.0, 7.0], 5, 1, True),  | 
 | 174 | +    ],  | 
 | 175 | +)  | 
 | 176 | +def test_rolling_mean_expr_lazy_ungrouped(  | 
 | 177 | +    constructor: Constructor,  | 
 | 178 | +    expected_a: list[float],  | 
 | 179 | +    window_size: int,  | 
 | 180 | +    min_samples: int,  | 
 | 181 | +    *,  | 
 | 182 | +    center: bool,  | 
 | 183 | +) -> None:  | 
 | 184 | +    if ("polars" in str(constructor) and POLARS_VERSION < (1, 10)) or (  | 
 | 185 | +        "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)  | 
 | 186 | +    ):  | 
 | 187 | +        pytest.skip()  | 
 | 188 | +    if "modin" in str(constructor):  | 
 | 189 | +        # unreliable  | 
 | 190 | +        pytest.skip()  | 
 | 191 | +    data = {  | 
 | 192 | +        "a": [1, None, 2, None, 4, 6, 11],  | 
 | 193 | +        "b": [1, None, 2, 3, 4, 5, 6],  | 
 | 194 | +        "i": list(range(7)),  | 
 | 195 | +    }  | 
 | 196 | +    df = nw.from_native(constructor(data))  | 
 | 197 | +    result = (  | 
 | 198 | +        df.with_columns(  | 
 | 199 | +            nw.col("a")  | 
 | 200 | +            .rolling_mean(window_size, min_samples=min_samples, center=center)  | 
 | 201 | +            .over(order_by="b")  | 
 | 202 | +        )  | 
 | 203 | +        .select("a", "i")  | 
 | 204 | +        .sort("i")  | 
 | 205 | +    )  | 
 | 206 | +    expected = {"a": expected_a, "i": list(range(7))}  | 
 | 207 | +    assert_equal_data(result, expected)  | 
0 commit comments