
Commit e313853

flox: don't set fill_value where possible (#9433)
* flox: don't set fill_value where possible. Closes #8090, closes #8206, closes #9398
* Update doctest
* Fix test
* fix test
* Test for flox >= 0.9.12
* fix whats-new
1 parent 9cb9958 commit e313853
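
The user-visible effect of this commit is that groupby/resample reductions routed through flox no longer force a float fill_value when every group has at least one observation, so integer dtypes survive reductions such as sum. Below is a minimal sketch of that behavior with made-up data; it assumes flox >= 0.9.12 is installed so the new code path is taken.

import numpy as np
import xarray as xr

# Integer data where both groups ("a" and "b") have at least one member.
da = xr.DataArray(
    np.arange(6, dtype="int16"),
    dims="x",
    coords={"letters": ("x", ["a", "b", "a", "b", "a", "b"])},
)

with xr.set_options(use_flox=True):
    result = da.groupby("letters").sum()

# With no missing groups, no fill_value is injected, so the result dtype
# matches what numpy's own sum would produce (previously the NaN fill_value
# promoted it to float64).
assert result.dtype == np.sum(da.data, axis=0).dtype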

File tree

6 files changed, +73 -21 lines changed


doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ Bug fixes
   the non-missing times could in theory be encoded with integers
   (:issue:`9488`, :pull:`9497`). By `Spencer Clark
   <https://github.com/spencerkclark>`_.
+- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
+  By `Deepak Cherian <https://github.com/dcherian>`_.
 
 
 Documentation

xarray/core/dataarray.py

Lines changed: 2 additions & 2 deletions
@@ -6786,8 +6786,8 @@ def groupby(
 
 
         >>> da.groupby("letters").sum()
         <xarray.DataArray (letters: 2, y: 3)> Size: 48B
-        array([[ 9., 11., 13.],
-               [ 9., 11., 13.]])
+        array([[ 9, 11, 13],
+               [ 9, 11, 13]])
         Coordinates:
           * letters  (letters) object 16B 'a' 'b'
         Dimensions without coordinates: y

xarray/core/dataset.py

Lines changed: 1 addition & 1 deletion
@@ -10390,7 +10390,7 @@ def groupby(
           * letters  (letters) object 16B 'a' 'b'
         Dimensions without coordinates: y
         Data variables:
-            foo      (letters, y) float64 48B 9.0 11.0 13.0 9.0 11.0 13.0
+            foo      (letters, y) int64 48B 9 11 13 9 11 13
 
         Grouping by multiple variables

xarray/core/groupby.py

Lines changed: 12 additions & 18 deletions
@@ -791,14 +791,12 @@ def _maybe_restore_empty_groups(self, combined):
         """Our index contained empty groups (e.g., from a resampling or binning). If we
         reduced on that dimension, we want to restore the full index.
         """
-        from xarray.groupers import BinGrouper, TimeResampler
-
+        has_missing_groups = (
+            self.encoded.unique_coord.size != self.encoded.full_index.size
+        )
         indexers = {}
         for grouper in self.groupers:
-            if (
-                isinstance(grouper.grouper, BinGrouper | TimeResampler)
-                and grouper.name in combined.dims
-            ):
+            if has_missing_groups and grouper.name in combined._indexes:
                 indexers[grouper.name] = grouper.full_index
         if indexers:
             combined = combined.reindex(**indexers)

@@ -853,10 +851,6 @@ def _flox_reduce(
             else obj._coords
         )
 
-        any_isbin = any(
-            isinstance(grouper.grouper, BinGrouper) for grouper in self.groupers
-        )
-
         if keep_attrs is None:
             keep_attrs = _get_keep_attrs(default=True)
 
@@ -930,14 +924,14 @@ def _flox_reduce(
         ):
             raise ValueError(f"cannot reduce over dimensions {dim}.")
 
-        if kwargs["func"] not in ["all", "any", "count"]:
-            kwargs.setdefault("fill_value", np.nan)
-        if any_isbin and kwargs["func"] == "count":
-            # This is an annoying hack. Xarray returns np.nan
-            # when there are no observations in a bin, instead of 0.
-            # We can fake that here by forcing min_count=1.
-            # note min_count makes no sense in the xarray world
-            # as a kwarg for count, so this should be OK
+        has_missing_groups = (
+            self.encoded.unique_coord.size != self.encoded.full_index.size
+        )
+        if has_missing_groups or kwargs.get("min_count", 0) > 0:
+            # Xarray *always* returns np.nan when there are no observations in a group,
+            # We can fake that here by forcing min_count=1 when it is not set.
+            # This handles boolean reductions, and count
+            # See GH8090, GH9398
             kwargs.setdefault("fill_value", np.nan)
             kwargs.setdefault("min_count", 1)
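
In plain terms, the new logic checks whether any group in the full index received no observations and, only in that case (or when the user asked for min_count), passes flox a NaN fill_value with min_count=1 so empty groups come back as NaN for every reduction, including count, any, and all. The following is a condensed, standalone restatement of that decision for illustration only; it is not xarray's actual method and the helper name and parameters are hypothetical.

import numpy as np

def adjust_flox_kwargs(kwargs, n_unique_groups, n_full_index_groups):
    # Illustrative helper (not part of xarray): mirrors the condition used in
    # GroupBy._flox_reduce after this commit.
    has_missing_groups = n_unique_groups != n_full_index_groups
    if has_missing_groups or kwargs.get("min_count", 0) > 0:
        # Empty groups should come back as NaN (xarray's eager behaviour),
        # so force a NaN fill_value and min_count=1 unless the caller already
        # requested a stricter min_count.
        kwargs.setdefault("fill_value", np.nan)
        kwargs.setdefault("min_count", 1)
    return kwargs

# Example: a binned/resampled reduction where 10 bins exist but only 8 saw data.
print(adjust_flox_kwargs({"func": "count"}, n_unique_groups=8, n_full_index_groups=10))
# -> {'func': 'count', 'fill_value': nan, 'min_count': 1}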

xarray/tests/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ def _importorskip(
     not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck"
 )
 has_numpy_2, requires_numpy_2 = _importorskip("numpy", "2.0.0")
+_, requires_flox_0_9_12 = _importorskip("flox", "0.9.12")
 
 has_array_api_strict, requires_array_api_strict = _importorskip("array_api_strict")

xarray/tests/test_groupby.py

Lines changed: 55 additions & 0 deletions
@@ -34,6 +34,7 @@
     requires_cftime,
     requires_dask,
     requires_flox,
+    requires_flox_0_9_12,
     requires_scipy,
 )
 
@@ -2859,6 +2860,60 @@ def test_multiple_groupers_mixed(use_flox) -> None:
 # ------
 
 
+@requires_flox_0_9_12
+@pytest.mark.parametrize(
+    "reduction", ["max", "min", "nanmax", "nanmin", "sum", "nansum", "prod", "nanprod"]
+)
+def test_groupby_preserve_dtype(reduction):
+    # all groups are present, we should follow numpy exactly
+    ds = xr.Dataset(
+        {
+            "test": (
+                ["x", "y"],
+                np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16"),
+            )
+        },
+        coords={"idx": ("x", [1, 2, 1])},
+    )
+
+    kwargs = {}
+    if "nan" in reduction:
+        kwargs["skipna"] = True
+    # TODO: fix dtype with numbagg/bottleneck and use_flox=False
+    with xr.set_options(use_numbagg=False, use_bottleneck=False):
+        actual = getattr(ds.groupby("idx"), reduction.removeprefix("nan"))(
+            **kwargs
+        ).test.dtype
+    expected = getattr(np, reduction)(ds.test.data, axis=0).dtype
+
+    assert actual == expected
+
+
+@requires_dask
+@requires_flox_0_9_12
+@pytest.mark.parametrize("reduction", ["any", "all", "count"])
+def test_gappy_resample_reductions(reduction):
+    # GH8090
+    dates = (("1988-12-01", "1990-11-30"), ("2000-12-01", "2001-11-30"))
+    times = [xr.date_range(*d, freq="D") for d in dates]
+
+    da = xr.concat(
+        [
+            xr.DataArray(np.random.rand(len(t)), coords={"time": t}, dims="time")
+            for t in times
+        ],
+        dim="time",
+    ).chunk(time=100)
+
+    rs = (da > 0.5).resample(time="YS-DEC")
+    method = getattr(rs, reduction)
+    with xr.set_options(use_flox=True):
+        actual = method(dim="time")
+    with xr.set_options(use_flox=False):
+        expected = method(dim="time")
+    assert_identical(expected, actual)
+
+
 # Possible property tests
 # 1. lambda x: x
 # 2. grouped-reduce on unique coords is identical to array
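
As a usage-level illustration of what test_gappy_resample_reductions verifies: with this fix, resample reductions over periods that contain no observations return NaN on the flox path, matching the eager (use_flox=False) path instead of returning 0 or False. A small sketch with made-up data; assumes flox >= 0.9.12.

import numpy as np
import xarray as xr

# Daily data for 2000 and 2005 only; the years in between have no samples.
pieces = [
    xr.DataArray(np.random.rand(len(t)), coords={"time": t}, dims="time")
    for t in (
        xr.date_range("2000-01-01", "2000-12-31", freq="D"),
        xr.date_range("2005-01-01", "2005-12-31", freq="D"),
    )
]
da = xr.concat(pieces, dim="time")

with xr.set_options(use_flox=True):
    counts = (da > 0.5).resample(time="YS").count()

# Years 2001-2004 appear in the full resampled index but received no data;
# they now come back as NaN rather than 0, matching use_flox=False.
print(counts.sel(time="2002"))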
