Skip to content

Commit e250895

Browse files
committed
Better binning API
1 parent 01fbf50 commit e250895

File tree

4 files changed

+64
-22
lines changed

4 files changed

+64
-22
lines changed

xarray/core/dataarray.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6805,6 +6805,7 @@ def groupby_bins(
68056805
include_lowest: bool = False,
68066806
squeeze: bool | None = None,
68076807
restore_coord_dims: bool = False,
6808+
duplicates: Literal["raise", "drop"] = "raise",
68086809
) -> DataArrayGroupBy:
68096810
"""Returns a DataArrayGroupBy object for performing grouped operations.
68106811
@@ -6841,6 +6842,8 @@ def groupby_bins(
68416842
restore_coord_dims : bool, default: False
68426843
If True, also restore the dimension order of multi-dimensional
68436844
coordinates.
6845+
duplicates : {"raise", "drop"}, default: "raise"
6846+
If bin edges are not unique, either raise a ValueError ("raise") or drop the non-unique edges ("drop").
68446847
68456848
Returns
68466849
-------
@@ -6873,12 +6876,10 @@ def groupby_bins(
68736876
_validate_groupby_squeeze(squeeze)
68746877
grouper = BinGrouper(
68756878
bins=bins,
6876-
cut_kwargs={
6877-
"right": right,
6878-
"labels": labels,
6879-
"precision": precision,
6880-
"include_lowest": include_lowest,
6881-
},
6879+
right=right,
6880+
labels=labels,
6881+
precision=precision,
6882+
include_lowest=include_lowest,
68826883
)
68836884
rgrouper = ResolvedGrouper(grouper, group, self)
68846885

xarray/core/dataset.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10342,6 +10342,7 @@ def groupby_bins(
1034210342
include_lowest: bool = False,
1034310343
squeeze: bool | None = None,
1034410344
restore_coord_dims: bool = False,
10345+
duplicates: Literal["raise", "drop"] = "raise",
1034510346
) -> DatasetGroupBy:
1034610347
"""Returns a DatasetGroupBy object for performing grouped operations.
1034710348
@@ -10378,6 +10379,8 @@ def groupby_bins(
1037810379
restore_coord_dims : bool, default: False
1037910380
If True, also restore the dimension order of multi-dimensional
1038010381
coordinates.
10382+
duplicates : {"raise", "drop"}, default: "raise"
10383+
If bin edges are not unique, either raise a ValueError ("raise") or drop the non-unique edges ("drop").
1038110384
1038210385
Returns
1038310386
-------
@@ -10410,12 +10413,10 @@ def groupby_bins(
1041010413
_validate_groupby_squeeze(squeeze)
1041110414
grouper = BinGrouper(
1041210415
bins=bins,
10413-
cut_kwargs={
10414-
"right": right,
10415-
"labels": labels,
10416-
"precision": precision,
10417-
"include_lowest": include_lowest,
10418-
},
10416+
right=right,
10417+
labels=labels,
10418+
precision=precision,
10419+
include_lowest=include_lowest,
1041910420
)
1042010421
rgrouper = ResolvedGrouper(grouper, group, self)
1042110422

xarray/core/groupers.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88

99
import datetime
1010
from abc import ABC, abstractmethod
11-
from collections.abc import Mapping, Sequence
11+
from collections.abc import Sequence
1212
from dataclasses import dataclass, field
13+
from typing import Any, Literal
1314

1415
import numpy as np
1516
import pandas as pd
@@ -195,14 +196,46 @@ class BinGrouper(Grouper):
195196
196197
Attributes
197198
----------
198-
bins: int, sequence of scalars, or IntervalIndex
199-
Speciication for bins either as integer, or as bin edges.
200-
cut_kwargs: dict
201-
Keyword arguments forwarded to :py:func:`pandas.cut`.
199+
bins : int, sequence of scalars, or IntervalIndex
200+
The criteria to bin by.
201+
202+
* int : Defines the number of equal-width bins in the range of `x`. The
203+
range of `x` is extended by .1% on each side to include the minimum
204+
and maximum values of `x`.
205+
* sequence of scalars : Defines the bin edges allowing for non-uniform
206+
width. No extension of the range of `x` is done.
207+
* IntervalIndex : Defines the exact bins to be used. Note that
208+
IntervalIndex for `bins` must be non-overlapping.
209+
210+
right : bool, default True
211+
Indicates whether `bins` includes the rightmost edge or not. If
212+
``right == True`` (the default), then the `bins` ``[1, 2, 3, 4]``
213+
indicate (1,2], (2,3], (3,4]. This argument is ignored when
214+
`bins` is an IntervalIndex.
215+
labels : array or False, default None
216+
Specifies the labels for the returned bins. Must be the same length as
217+
the resulting bins. If False, returns only integer indicators of the
218+
bins. This affects the type of the output container (see below).
219+
This argument is ignored when `bins` is an IntervalIndex. If True,
220+
raises an error. When `ordered=False`, labels must be provided.
221+
retbins : bool, default False
222+
Whether to return the bins or not. Useful when bins is provided
223+
as a scalar.
224+
precision : int, default 3
225+
The precision at which to store and display the bins labels.
226+
include_lowest : bool, default False
227+
Whether the first interval should be left-inclusive or not.
228+
duplicates : {"raise", "drop"}, default: "raise"
229+
If bin edges are not unique, either raise a ValueError ("raise") or drop the non-unique edges ("drop").
202230
"""
203231

204232
bins: int | Sequence | pd.IntervalIndex
205-
cut_kwargs: Mapping = field(default_factory=dict)
233+
# The rest are copied from pandas
234+
right: bool = True
235+
labels: Any = None
236+
precision: int = 3
237+
include_lowest: bool = False
238+
duplicates: Literal["raise", "drop"] = "raise"
206239

207240
def __post_init__(self) -> None:
208241
if duck_array_ops.isnull(self.bins).all():
@@ -213,7 +246,16 @@ def factorize(self, group: T_Group) -> EncodedGroups:
213246

214247
data = group.data
215248

216-
binned, self.bins = pd.cut(data, self.bins, **self.cut_kwargs, retbins=True)
249+
binned, self.bins = pd.cut(
250+
data,
251+
bins=self.bins,
252+
right=self.right,
253+
labels=self.labels,
254+
precision=self.precision,
255+
include_lowest=self.include_lowest,
256+
duplicates=self.duplicates,
257+
retbins=True,
258+
)
217259

218260
binned_codes = binned.codes
219261
if (binned_codes == -1).all():

xarray/tests/test_groupby.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,9 +1034,7 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:
10341034

10351035
with xr.set_options(use_flox=use_flox):
10361036
actual = da.groupby(
1037-
x=BinGrouper(
1038-
bins=x_bins, cut_kwargs=dict(include_lowest=True, right=False)
1039-
),
1037+
x=BinGrouper(bins=x_bins, include_lowest=True, right=False),
10401038
).mean()
10411039
assert_identical(expected, actual)
10421040

0 commit comments

Comments
 (0)