Skip to content

Commit f28d7f8

Browse files
authored
Fix DataArrayRolling.__iter__ with center=True (#6744)
* new test_rolling module * fix rolling iter with center=True * add fix to whats-new * fix DatasetRolling test names * small code simplification
1 parent e5fcd79 commit f28d7f8

File tree

3 files changed

+19
-16
lines changed

3 files changed

+19
-16
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ Bug fixes
5454
- :py:meth:`open_dataset` with dask and ``~`` in the path now resolves the home directory
5555
instead of raising an error. (:issue:`6707`, :pull:`6710`)
5656
By `Michael Niklas <https://github.com/headtr1ck>`_.
57+
- :py:meth:`DataArrayRolling.__iter__` with ``center=True`` now works correctly.
58+
(:issue:`6739`, :pull:`6744`)
59+
By `Michael Niklas <https://github.com/headtr1ck>`_.
5760

5861
Documentation
5962
~~~~~~~~~~~~~

xarray/core/rolling.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,16 +267,21 @@ def __init__(
267267
# TODO legacy attribute
268268
self.window_labels = self.obj[self.dim[0]]
269269

270-
def __iter__(self) -> Iterator[tuple[RollingKey, DataArray]]:
270+
def __iter__(self) -> Iterator[tuple[DataArray, DataArray]]:
271271
if self.ndim > 1:
272272
raise ValueError("__iter__ is only supported for 1d-rolling")
273-
stops = np.arange(1, len(self.window_labels) + 1)
274-
starts = stops - int(self.window[0])
275-
starts[: int(self.window[0])] = 0
273+
274+
dim0 = self.dim[0]
275+
window0 = int(self.window[0])
276+
offset = (window0 + 1) // 2 if self.center[0] else 1
277+
stops = np.arange(offset, self.obj.sizes[dim0] + offset)
278+
starts = stops - window0
279+
starts[: window0 - offset] = 0
280+
276281
for (label, start, stop) in zip(self.window_labels, starts, stops):
277-
window = self.obj.isel({self.dim[0]: slice(start, stop)})
282+
window = self.obj.isel({dim0: slice(start, stop)})
278283

279-
counts = window.count(dim=self.dim[0])
284+
counts = window.count(dim=dim0)
280285
window = window.where(counts >= self.min_periods)
281286

282287
yield (label, window)

xarray/tests/test_rolling.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727

2828
class TestDataArrayRolling:
2929
@pytest.mark.parametrize("da", (1, 2), indirect=True)
30-
def test_rolling_iter(self, da) -> None:
31-
rolling_obj = da.rolling(time=7)
30+
@pytest.mark.parametrize("center", [True, False])
31+
@pytest.mark.parametrize("size", [1, 2, 3, 7])
32+
def test_rolling_iter(self, da: DataArray, center: bool, size: int) -> None:
33+
rolling_obj = da.rolling(time=size, center=center)
3234
rolling_obj_mean = rolling_obj.mean()
3335

3436
assert len(rolling_obj.window_labels) == len(da["time"])
@@ -40,14 +42,7 @@ def test_rolling_iter(self, da) -> None:
4042
actual = rolling_obj_mean.isel(time=i)
4143
expected = window_da.mean("time")
4244

43-
# TODO add assert_allclose_with_nan, which compares nan position
44-
# as well as the closeness of the values.
45-
assert_array_equal(actual.isnull(), expected.isnull())
46-
if (~actual.isnull()).sum() > 0:
47-
np.allclose(
48-
actual.values[actual.values.nonzero()],
49-
expected.values[expected.values.nonzero()],
50-
)
45+
np.testing.assert_allclose(actual.values, expected.values)
5146

5247
@pytest.mark.parametrize("da", (1,), indirect=True)
5348
def test_rolling_repr(self, da) -> None:

0 commit comments

Comments
 (0)