Skip to content

Commit e1936a9

Browse files
Ensure Coarsen.construct keeps all coords (#7233)
* test
* fix
* whatsnew
* group related tests into a class
* Update xarray/core/rolling.py
* Update xarray/core/rolling.py

Co-authored-by: Deepak Cherian <[email protected]>
1 parent 51d37d1 commit e1936a9

File tree

3 files changed

+89
-64
lines changed

3 files changed

+89
-64
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ Bug fixes
6161
now reopens the file from scratch for h5netcdf and scipy netCDF backends,
6262
rather than reusing a cached version (:issue:`4240`, :issue:`4862`).
6363
By `Stephan Hoyer <https://github.com/shoyer>`_.
64+
- Fixed bug where :py:meth:`Dataset.coarsen.construct` would demote non-dimension coordinates to variables. (:pull:`7233`)
65+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
6466
- Raise a TypeError when trying to plot empty data (:issue:`7156`, :pull:`7228`).
6567
By `Michael Niklas <https://github.com/headtr1ck>`_.
6668

xarray/core/rolling.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,10 @@ def construct(
973973
else:
974974
reshaped[key] = var
975975

976-
should_be_coords = set(window_dim) & set(self.obj.coords)
976+
# should handle window_dim being unindexed
977+
should_be_coords = (set(window_dim) & set(self.obj.coords)) | set(
978+
self.obj.coords
979+
)
977980
result = reshaped.set_coords(should_be_coords)
978981
if isinstance(self.obj, DataArray):
979982
return self.obj._from_temp_dataset(result)

xarray/tests/test_coarsen.py

Lines changed: 83 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -250,71 +250,91 @@ def test_coarsen_da_reduce(da, window, name) -> None:
250250
assert_allclose(actual, expected)
251251

252252

253-
@pytest.mark.parametrize("dask", [True, False])
254-
def test_coarsen_construct(dask: bool) -> None:
255-
256-
ds = Dataset(
257-
{
258-
"vart": ("time", np.arange(48), {"a": "b"}),
259-
"varx": ("x", np.arange(10), {"a": "b"}),
260-
"vartx": (("x", "time"), np.arange(480).reshape(10, 48), {"a": "b"}),
261-
"vary": ("y", np.arange(12)),
262-
},
263-
coords={"time": np.arange(48), "y": np.arange(12)},
264-
attrs={"foo": "bar"},
265-
)
266-
267-
if dask and has_dask:
268-
ds = ds.chunk({"x": 4, "time": 10})
269-
270-
expected = xr.Dataset(attrs={"foo": "bar"})
271-
expected["vart"] = (("year", "month"), ds.vart.data.reshape((-1, 12)), {"a": "b"})
272-
expected["varx"] = (("x", "x_reshaped"), ds.varx.data.reshape((-1, 5)), {"a": "b"})
273-
expected["vartx"] = (
274-
("x", "x_reshaped", "year", "month"),
275-
ds.vartx.data.reshape(2, 5, 4, 12),
276-
{"a": "b"},
277-
)
278-
expected["vary"] = ds.vary
279-
expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12)))
280-
281-
with raise_if_dask_computes():
282-
actual = ds.coarsen(time=12, x=5).construct(
283-
{"time": ("year", "month"), "x": ("x", "x_reshaped")}
253+
class TestCoarsenConstruct:
254+
@pytest.mark.parametrize("dask", [True, False])
255+
def test_coarsen_construct(self, dask: bool) -> None:
256+
257+
ds = Dataset(
258+
{
259+
"vart": ("time", np.arange(48), {"a": "b"}),
260+
"varx": ("x", np.arange(10), {"a": "b"}),
261+
"vartx": (("x", "time"), np.arange(480).reshape(10, 48), {"a": "b"}),
262+
"vary": ("y", np.arange(12)),
263+
},
264+
coords={"time": np.arange(48), "y": np.arange(12)},
265+
attrs={"foo": "bar"},
284266
)
285-
assert_identical(actual, expected)
286267

287-
with raise_if_dask_computes():
288-
actual = ds.coarsen(time=12, x=5).construct(
289-
time=("year", "month"), x=("x", "x_reshaped")
290-
)
291-
assert_identical(actual, expected)
268+
if dask and has_dask:
269+
ds = ds.chunk({"x": 4, "time": 10})
292270

293-
with raise_if_dask_computes():
294-
actual = ds.coarsen(time=12, x=5).construct(
295-
{"time": ("year", "month"), "x": ("x", "x_reshaped")}, keep_attrs=False
271+
expected = xr.Dataset(attrs={"foo": "bar"})
272+
expected["vart"] = (
273+
("year", "month"),
274+
ds.vart.data.reshape((-1, 12)),
275+
{"a": "b"},
296276
)
297-
for var in actual:
298-
assert actual[var].attrs == {}
299-
assert actual.attrs == {}
300-
301-
with raise_if_dask_computes():
302-
actual = ds.vartx.coarsen(time=12, x=5).construct(
303-
{"time": ("year", "month"), "x": ("x", "x_reshaped")}
277+
expected["varx"] = (
278+
("x", "x_reshaped"),
279+
ds.varx.data.reshape((-1, 5)),
280+
{"a": "b"},
304281
)
305-
assert_identical(actual, expected["vartx"])
306-
307-
with pytest.raises(ValueError):
308-
ds.coarsen(time=12).construct(foo="bar")
309-
310-
with pytest.raises(ValueError):
311-
ds.coarsen(time=12, x=2).construct(time=("year", "month"))
312-
313-
with pytest.raises(ValueError):
314-
ds.coarsen(time=12).construct()
315-
316-
with pytest.raises(ValueError):
317-
ds.coarsen(time=12).construct(time="bar")
318-
319-
with pytest.raises(ValueError):
320-
ds.coarsen(time=12).construct(time=("bar",))
282+
expected["vartx"] = (
283+
("x", "x_reshaped", "year", "month"),
284+
ds.vartx.data.reshape(2, 5, 4, 12),
285+
{"a": "b"},
286+
)
287+
expected["vary"] = ds.vary
288+
expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12)))
289+
290+
with raise_if_dask_computes():
291+
actual = ds.coarsen(time=12, x=5).construct(
292+
{"time": ("year", "month"), "x": ("x", "x_reshaped")}
293+
)
294+
assert_identical(actual, expected)
295+
296+
with raise_if_dask_computes():
297+
actual = ds.coarsen(time=12, x=5).construct(
298+
time=("year", "month"), x=("x", "x_reshaped")
299+
)
300+
assert_identical(actual, expected)
301+
302+
with raise_if_dask_computes():
303+
actual = ds.coarsen(time=12, x=5).construct(
304+
{"time": ("year", "month"), "x": ("x", "x_reshaped")}, keep_attrs=False
305+
)
306+
for var in actual:
307+
assert actual[var].attrs == {}
308+
assert actual.attrs == {}
309+
310+
with raise_if_dask_computes():
311+
actual = ds.vartx.coarsen(time=12, x=5).construct(
312+
{"time": ("year", "month"), "x": ("x", "x_reshaped")}
313+
)
314+
assert_identical(actual, expected["vartx"])
315+
316+
with pytest.raises(ValueError):
317+
ds.coarsen(time=12).construct(foo="bar")
318+
319+
with pytest.raises(ValueError):
320+
ds.coarsen(time=12, x=2).construct(time=("year", "month"))
321+
322+
with pytest.raises(ValueError):
323+
ds.coarsen(time=12).construct()
324+
325+
with pytest.raises(ValueError):
326+
ds.coarsen(time=12).construct(time="bar")
327+
328+
with pytest.raises(ValueError):
329+
ds.coarsen(time=12).construct(time=("bar",))
330+
331+
def test_coarsen_construct_keeps_all_coords(self):
332+
da = xr.DataArray(np.arange(24), dims=["time"])
333+
da = da.assign_coords(day=365 * da)
334+
335+
result = da.coarsen(time=12).construct(time=("year", "month"))
336+
assert list(da.coords) == list(result.coords)
337+
338+
ds = da.to_dataset(name="T")
339+
result = ds.coarsen(time=12).construct(time=("year", "month"))
340+
assert list(da.coords) == list(result.coords)

0 commit comments

Comments
 (0)