Skip to content

Commit 65d500c

Browse files
committed
Grouper, Resampler as public API
1 parent 2645d7f commit 65d500c

File tree

7 files changed

+109
-71
lines changed

7 files changed

+109
-71
lines changed

xarray/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from xarray.coding.cftimeindex import CFTimeIndex
1515
from xarray.coding.frequencies import infer_freq
1616
from xarray.conventions import SerializationWarning, decode_cf
17+
from xarray.core import groupers
1718
from xarray.core.alignment import align, broadcast
1819
from xarray.core.combine import combine_by_coords, combine_nested
1920
from xarray.core.common import ALL_DIMS, full_like, ones_like, zeros_like
@@ -94,6 +95,8 @@
9495
"unify_chunks",
9596
"where",
9697
"zeros_like",
98+
# Submodules
99+
"groupers",
97100
# Classes
98101
"CFTimeIndex",
99102
"Context",

xarray/core/common.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
from xarray.core.dataarray import DataArray
4040
from xarray.core.dataset import Dataset
41+
from xarray.core.groupers import Resampler
4142
from xarray.core.indexes import Index
4243
from xarray.core.resample import Resample
4344
from xarray.core.rolling_exp import RollingExp
@@ -876,7 +877,7 @@ def rolling_exp(
876877
def _resample(
877878
self,
878879
resample_cls: type[T_Resample],
879-
indexer: Mapping[Any, str] | None,
880+
indexer: Mapping[Hashable, str | Resampler] | None,
880881
skipna: bool | None,
881882
closed: SideOptions | None,
882883
label: SideOptions | None,
@@ -885,7 +886,7 @@ def _resample(
885886
origin: str | DatetimeLike,
886887
loffset: datetime.timedelta | str | None,
887888
restore_coord_dims: bool | None,
888-
**indexer_kwargs: str,
889+
**indexer_kwargs: str | Resampler,
889890
) -> T_Resample:
890891
"""Returns a Resample object for performing resampling operations.
891892
@@ -1068,7 +1069,7 @@ def _resample(
10681069

10691070
from xarray.core.dataarray import DataArray
10701071
from xarray.core.groupby import ResolvedGrouper
1071-
from xarray.core.groupers import TimeResampler
1072+
from xarray.core.groupers import Resampler, TimeResampler
10721073
from xarray.core.resample import RESAMPLE_DIM
10731074

10741075
# note: the second argument (now 'skipna') used to be 'dim'
@@ -1098,15 +1099,20 @@ def _resample(
10981099
name=RESAMPLE_DIM,
10991100
)
11001101

1101-
grouper = TimeResampler(
1102-
freq=freq,
1103-
closed=closed,
1104-
label=label,
1105-
origin=origin,
1106-
offset=offset,
1107-
loffset=loffset,
1108-
base=base,
1109-
)
1102+
grouper: Resampler
1103+
if isinstance(freq, str):
1104+
grouper = TimeResampler(
1105+
freq=freq,
1106+
closed=closed,
1107+
label=label,
1108+
origin=origin,
1109+
offset=offset,
1110+
loffset=loffset,
1111+
base=base,
1112+
)
1113+
else:
1114+
assert isinstance(freq, Resampler)
1115+
grouper = freq
11101116

11111117
rgrouper = ResolvedGrouper(grouper, group, self)
11121118

xarray/core/dataarray.py

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
from xarray.backends import ZarrStore
8888
from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
8989
from xarray.core.groupby import DataArrayGroupBy
90+
from xarray.core.groupers import Grouper, Resampler
9091
from xarray.core.resample import DataArrayResample
9192
from xarray.core.rolling import DataArrayCoarsen, DataArrayRolling
9293
from xarray.core.types import (
@@ -6682,9 +6683,10 @@ def interp_calendar(
66826683

66836684
def groupby(
66846685
self,
6685-
group: Hashable | DataArray | IndexVariable,
6686+
group: Hashable | DataArray | IndexVariable | None = None,
66866687
squeeze: bool | None = None,
66876688
restore_coord_dims: bool = False,
6689+
**groupers: Grouper,
66886690
) -> DataArrayGroupBy:
66896691
"""Returns a DataArrayGroupBy object for performing grouped operations.
66906692
@@ -6700,6 +6702,10 @@ def groupby(
67006702
restore_coord_dims : bool, default: False
67016703
If True, also restore the dimension order of multi-dimensional
67026704
coordinates.
6705+
**groupers : Mapping of hashable to Grouper or Resampler
6706+
Maps the name of the variable to group by onto a ``Grouper`` or ``Resampler`` object.
6707+
One of ``group`` or ``groupers`` must be provided.
6708+
Only a single ``grouper`` is allowed at present.
67036709
67046710
Returns
67056711
-------
@@ -6729,6 +6735,15 @@ def groupby(
67296735
* time (time) datetime64[ns] 15kB 2000-01-01 2000-01-02 ... 2004-12-31
67306736
dayofyear (time) int64 15kB 1 2 3 4 5 6 7 8 ... 360 361 362 363 364 365 366
67316737
6738+
Use a ``Grouper`` object to be more explicit
6739+
6740+
>>> da.coords["dayofyear"] = da.time.dt.dayofyear
6741+
>>> da.groupby(dayofyear=xr.groupers.UniqueGrouper()).mean()
6742+
<xarray.DataArray (dayofyear: 366)> Size: 3kB
6743+
array([ 730.8, 731.8, 732.8, ..., 1093.8, 1094.8, 1095.5])
6744+
Coordinates:
6745+
* dayofyear (dayofyear) int64 3kB 1 2 3 4 5 6 7 ... 361 362 363 364 365 366
6746+
67326747
See Also
67336748
--------
67346749
:ref:`groupby`
@@ -6756,7 +6771,19 @@ def groupby(
67566771
from xarray.core.groupers import UniqueGrouper
67576772

67586773
_validate_groupby_squeeze(squeeze)
6759-
rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
6774+
6775+
if group is not None:
6776+
assert not groupers
6777+
grouper = UniqueGrouper()
6778+
else:
6779+
if len(groupers) > 1:
6780+
raise ValueError("grouping by multiple variables is not supported yet.")
6781+
if not groupers:
6782+
raise ValueError
6783+
group, grouper = next(iter(groupers.items()))
6784+
6785+
rgrouper = ResolvedGrouper(grouper, group, self)
6786+
67606787
return DataArrayGroupBy(
67616788
self,
67626789
(rgrouper,),
@@ -7189,7 +7216,7 @@ def coarsen(
71897216

71907217
def resample(
71917218
self,
7192-
indexer: Mapping[Any, str] | None = None,
7219+
indexer: Mapping[Hashable, str | Resampler] | None = None,
71937220
skipna: bool | None = None,
71947221
closed: SideOptions | None = None,
71957222
label: SideOptions | None = None,
@@ -7198,7 +7225,7 @@ def resample(
71987225
origin: str | DatetimeLike = "start_day",
71997226
loffset: datetime.timedelta | str | None = None,
72007227
restore_coord_dims: bool | None = None,
7201-
**indexer_kwargs: str,
7228+
**indexer_kwargs: str | Resampler,
72027229
) -> DataArrayResample:
72037230
"""Returns a Resample object for performing resampling operations.
72047231
@@ -7291,28 +7318,7 @@ def resample(
72917318
0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323,
72927319
0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355,
72937320
0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387,
7294-
0.96774194, 1. , 1.03225806, 1.06451613, 1.09677419,
7295-
1.12903226, 1.16129032, 1.19354839, 1.22580645, 1.25806452,
7296-
1.29032258, 1.32258065, 1.35483871, 1.38709677, 1.41935484,
7297-
1.4516129 , 1.48387097, 1.51612903, 1.5483871 , 1.58064516,
7298-
1.61290323, 1.64516129, 1.67741935, 1.70967742, 1.74193548,
7299-
1.77419355, 1.80645161, 1.83870968, 1.87096774, 1.90322581,
7300-
1.93548387, 1.96774194, 2. , 2.03448276, 2.06896552,
7301-
2.10344828, 2.13793103, 2.17241379, 2.20689655, 2.24137931,
7302-
2.27586207, 2.31034483, 2.34482759, 2.37931034, 2.4137931 ,
7303-
2.44827586, 2.48275862, 2.51724138, 2.55172414, 2.5862069 ,
7304-
2.62068966, 2.65517241, 2.68965517, 2.72413793, 2.75862069,
7305-
2.79310345, 2.82758621, 2.86206897, 2.89655172, 2.93103448,
7306-
2.96551724, 3. , 3.03225806, 3.06451613, 3.09677419,
7307-
3.12903226, 3.16129032, 3.19354839, 3.22580645, 3.25806452,
7308-
...
7309-
7.87096774, 7.90322581, 7.93548387, 7.96774194, 8. ,
7310-
8.03225806, 8.06451613, 8.09677419, 8.12903226, 8.16129032,
7311-
8.19354839, 8.22580645, 8.25806452, 8.29032258, 8.32258065,
7312-
8.35483871, 8.38709677, 8.41935484, 8.4516129 , 8.48387097,
7313-
8.51612903, 8.5483871 , 8.58064516, 8.61290323, 8.64516129,
7314-
8.67741935, 8.70967742, 8.74193548, 8.77419355, 8.80645161,
7315-
8.83870968, 8.87096774, 8.90322581, 8.93548387, 8.96774194,
7321+
0.96774194, 1. , ...,
73167322
9. , 9.03333333, 9.06666667, 9.1 , 9.13333333,
73177323
9.16666667, 9.2 , 9.23333333, 9.26666667, 9.3 ,
73187324
9.33333333, 9.36666667, 9.4 , 9.43333333, 9.46666667,
@@ -7342,19 +7348,7 @@ def resample(
73427348
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 3.,
73437349
3., 3., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
73447350
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
7345-
nan, nan, nan, nan, 4., 4., 4., nan, nan, nan, nan, nan, nan,
7346-
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
7347-
nan, nan, nan, nan, nan, nan, nan, nan, 5., 5., 5., nan, nan,
7348-
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
7349-
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
7350-
6., 6., 6., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
7351-
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
7352-
nan, nan, nan, nan, 7., 7., 7., nan, nan, nan, nan, nan, nan,
7353-
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
7354-
nan, nan, nan, nan, nan, nan, nan, nan, nan, 8., 8., 8., nan,
7355-
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
7356-
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
7357-
nan, 9., 9., 9., nan, nan, nan, nan, nan, nan, nan, nan, nan,
7351+
nan, nan, nan, nan, 4., 4., 4., nan, nan, nan, nan, nan, ...,
73587352
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
73597353
nan, nan, nan, nan, nan, 10., 10., 10., nan, nan, nan, nan, nan,
73607354
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,

xarray/core/dataset.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
138138
from xarray.core.dataarray import DataArray
139139
from xarray.core.groupby import DatasetGroupBy
140+
from xarray.core.groupers import Grouper, Resampler
140141
from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult
141142
from xarray.core.resample import DatasetResample
142143
from xarray.core.rolling import DatasetCoarsen, DatasetRolling
@@ -10256,9 +10257,10 @@ def interp_calendar(
1025610257

1025710258
def groupby(
1025810259
self,
10259-
group: Hashable | DataArray | IndexVariable,
10260+
group: Hashable | DataArray | IndexVariable | None = None,
1026010261
squeeze: bool | None = None,
1026110262
restore_coord_dims: bool = False,
10263+
**groupers: Grouper,
1026210264
) -> DatasetGroupBy:
1026310265
"""Returns a DatasetGroupBy object for performing grouped operations.
1026410266
@@ -10274,6 +10276,10 @@ def groupby(
1027410276
restore_coord_dims : bool, default: False
1027510277
If True, also restore the dimension order of multi-dimensional
1027610278
coordinates.
10279+
**groupers : Mapping of hashable to Grouper or Resampler
10280+
Maps the name of the variable to group by onto a ``Grouper`` or ``Resampler`` object.
10281+
One of ``group`` or ``groupers`` must be provided.
10282+
Only a single ``grouper`` is allowed at present.
1027710283
1027810284
Returns
1027910285
-------
@@ -10308,7 +10314,16 @@ def groupby(
1030810314
from xarray.core.groupers import UniqueGrouper
1030910315

1031010316
_validate_groupby_squeeze(squeeze)
10311-
rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
10317+
if group is not None:
10318+
assert not groupers
10319+
rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
10320+
else:
10321+
if len(groupers) > 1:
10322+
raise ValueError("grouping by multiple variables is not supported yet.")
10323+
if not groupers:
10324+
raise ValueError
10325+
for group, grouper in groupers.items():
10326+
rgrouper = ResolvedGrouper(grouper, group, self)
1031210327

1031310328
return DatasetGroupBy(
1031410329
self,
@@ -10587,7 +10602,7 @@ def coarsen(
1058710602

1058810603
def resample(
1058910604
self,
10590-
indexer: Mapping[Any, str] | None = None,
10605+
indexer: Mapping[Hashable, str | Resampler] | None = None,
1059110606
skipna: bool | None = None,
1059210607
closed: SideOptions | None = None,
1059310608
label: SideOptions | None = None,
@@ -10596,7 +10611,7 @@ def resample(
1059610611
origin: str | DatetimeLike = "start_day",
1059710612
loffset: datetime.timedelta | str | None = None,
1059810613
restore_coord_dims: bool | None = None,
10599-
**indexer_kwargs: str,
10614+
**indexer_kwargs: str | Resampler,
1060010615
) -> DatasetResample:
1060110616
"""Returns a Resample object for performing resampling operations.
1060210617

xarray/core/groupby.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def attrs(self) -> dict:
226226

227227
def __getitem__(self, key):
228228
if isinstance(key, tuple):
229-
key = key[0]
229+
(key,) = key
230230
return self.values[key]
231231

232232
def to_index(self) -> pd.Index:

xarray/core/groupers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from abc import ABC, abstractmethod
1111
from collections.abc import Mapping, Sequence
1212
from dataclasses import dataclass, field
13-
from typing import Any
1413

1514
import numpy as np
1615
import pandas as pd

xarray/tests/test_groupby.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212

1313
import xarray as xr
1414
from xarray import DataArray, Dataset, Variable
15-
from xarray.core.groupby import _consolidate_slices
15+
from xarray.core.groupby import (
16+
BinGrouper,
17+
UniqueGrouper,
18+
_consolidate_slices,
19+
)
1620
from xarray.core.types import InterpOptions
1721
from xarray.tests import (
1822
InaccessibleArray,
@@ -113,8 +117,9 @@ def test_multi_index_groupby_map(dataset) -> None:
113117
assert_equal(expected, actual)
114118

115119

116-
def test_reduce_numeric_only(dataset) -> None:
117-
gb = dataset.groupby("x", squeeze=False)
120+
@pytest.mark.parametrize("grouper", [dict(group="x"), dict(x=UniqueGrouper())])
121+
def test_reduce_numeric_only(dataset, grouper: dict) -> None:
122+
gb = dataset.groupby(**grouper, squeeze=False)
118123
with xr.set_options(use_flox=False):
119124
expected = gb.sum()
120125
with xr.set_options(use_flox=True):
@@ -883,11 +888,12 @@ def test_groupby_dataset_reduce() -> None:
883888

884889
expected = data.mean("y")
885890
expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3})
886-
actual = data.groupby("x").mean(...)
887-
assert_allclose(expected, actual)
891+
for gb in [data.groupby("x"), data.groupby(x=UniqueGrouper())]:
892+
actual = gb.mean(...)
893+
assert_allclose(expected, actual)
888894

889-
actual = data.groupby("x").mean("y")
890-
assert_allclose(expected, actual)
895+
actual = gb.mean("y")
896+
assert_allclose(expected, actual)
891897

892898
letters = data["letters"]
893899
expected = Dataset(
@@ -897,8 +903,9 @@ def test_groupby_dataset_reduce() -> None:
897903
"yonly": data["yonly"].groupby(letters).mean(),
898904
}
899905
)
900-
actual = data.groupby("letters").mean(...)
901-
assert_allclose(expected, actual)
906+
for gb in [data.groupby("letters"), data.groupby(letters=UniqueGrouper())]:
907+
actual = gb.mean(...)
908+
assert_allclose(expected, actual)
902909

903910

904911
@pytest.mark.parametrize("squeeze", [True, False])
@@ -1028,6 +1035,14 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:
10281035
)
10291036
assert_identical(expected, actual)
10301037

1038+
with xr.set_options(use_flox=use_flox):
1039+
actual = da.groupby(
1040+
x=BinGrouper(
1041+
bins=x_bins, cut_kwargs=dict(include_lowest=True, right=False)
1042+
),
1043+
).mean()
1044+
assert_identical(expected, actual)
1045+
10311046

10321047
@pytest.mark.parametrize("indexed_coord", [True, False])
10331048
def test_groupby_bins_math(indexed_coord) -> None:
@@ -1036,11 +1051,17 @@ def test_groupby_bins_math(indexed_coord) -> None:
10361051
if indexed_coord:
10371052
da["x"] = np.arange(N)
10381053
da["y"] = np.arange(N)
1039-
g = da.groupby_bins("x", np.arange(0, N + 1, 3))
1040-
mean = g.mean()
1041-
expected = da.isel(x=slice(1, None)) - mean.isel(x_bins=("x", [0, 0, 0, 1, 1, 1]))
1042-
actual = g - mean
1043-
assert_identical(expected, actual)
1054+
1055+
for g in [
1056+
da.groupby_bins("x", np.arange(0, N + 1, 3)),
1057+
da.groupby(x=BinGrouper(bins=np.arange(0, N + 1, 3))),
1058+
]:
1059+
mean = g.mean()
1060+
expected = da.isel(x=slice(1, None)) - mean.isel(
1061+
x_bins=("x", [0, 0, 0, 1, 1, 1])
1062+
)
1063+
actual = g - mean
1064+
assert_identical(expected, actual)
10441065

10451066

10461067
def test_groupby_math_nD_group() -> None:

0 commit comments

Comments
 (0)