Skip to content

Commit db2d414

Browse files
authored
pandas=2.0 support (#7724)
* unpin pandas in the package metadata * unpin pandas in the ci environments [skip-rtd] * create the input arrays in the parametrization * split `test_sel_float` into variants * skip the float16 variant if pandas>=2.0 is installed * [skip-rtd] * add tests for `days_in_month` and its alias * make sure the name and dtype match the expected * actually verify that the dtype stays the same * apply the dtype for non-dask * always use `int32` to follow `pandas=2.0` * back to `int64` * same for the test * final undo of `int64` → `int32` * update the comment to make more sense * simplify the conversion of the expected data * change back to the old condition * cast float16 to float64 when creating indexes (but warn anyways) * convert float16 to float64 when selecting using arrays (all except 0d) * move the float16 variant to a separate test this allows us to check for expected warnings * explicitly type the kwargs as a mapping of str → str * reword the warning message * restore the pin * [skip-ci] [skip-rtd] * rerun to make sure we don't introduce failures with `pandas<2` * changelog
1 parent 7c7e383 commit db2d414

File tree

6 files changed

+90
-31
lines changed

6 files changed

+90
-31
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ New Features
3030
- Added ability to save ``DataArray`` objects directly to Zarr using :py:meth:`~xarray.DataArray.to_zarr`.
3131
(:issue:`7692`, :pull:`7693`) .
3232
By `Joe Hamman <https://github.com/jhamman>`_.
33+
- Support `pandas>=2.0` (:pull:`7724`)
34+
By `Justus Magin <https://github.com/keewis>`_.
3335

3436
Breaking changes
3537
~~~~~~~~~~~~~~~~

xarray/core/accessor_dt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _get_date_field(values, name, dtype):
115115
access_method, values, name, dtype=dtype, new_axis=new_axis, chunks=chunks
116116
)
117117
else:
118-
return access_method(values, name)
118+
return access_method(values, name).astype(dtype)
119119

120120

121121
def _round_through_series_or_index(values, name, freq):

xarray/core/indexes.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
PandasIndexingAdapter,
1616
PandasMultiIndexingAdapter,
1717
)
18-
from xarray.core.utils import Frozen, get_valid_numpy_dtype, is_dict_like, is_scalar
18+
from xarray.core.utils import (
19+
Frozen,
20+
emit_user_level_warning,
21+
get_valid_numpy_dtype,
22+
is_dict_like,
23+
is_scalar,
24+
)
1925

2026
if TYPE_CHECKING:
2127
from xarray.core.types import ErrorOptions, T_Index
@@ -166,9 +172,21 @@ def safe_cast_to_index(array: Any) -> pd.Index:
166172
elif isinstance(array, PandasIndexingAdapter):
167173
index = array.array
168174
else:
169-
kwargs = {}
170-
if hasattr(array, "dtype") and array.dtype.kind == "O":
171-
kwargs["dtype"] = object
175+
kwargs: dict[str, str] = {}
176+
if hasattr(array, "dtype"):
177+
if array.dtype.kind == "O":
178+
kwargs["dtype"] = "object"
179+
elif array.dtype == "float16":
180+
emit_user_level_warning(
181+
(
182+
"`pandas.Index` does not support the `float16` dtype."
183+
" Casting to `float64` for you, but in the future please"
184+
" manually cast to either `float32` and `float64`."
185+
),
186+
category=DeprecationWarning,
187+
)
188+
kwargs["dtype"] = "float64"
189+
172190
index = pd.Index(np.asarray(array), **kwargs)
173191

174192
return _maybe_cast_to_cftimeindex(index)
@@ -259,6 +277,8 @@ def get_indexer_nd(index, labels, method=None, tolerance=None):
259277
labels
260278
"""
261279
flat_labels = np.ravel(labels)
280+
if flat_labels.dtype == "float16":
281+
flat_labels = flat_labels.astype("float64")
262282
flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance)
263283
indexer = flat_indexer.reshape(labels.shape)
264284
return indexer

xarray/tests/test_accessor_dt.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def setup(self):
5959
"quarter",
6060
"date",
6161
"time",
62+
"daysinmonth",
63+
"days_in_month",
6264
"is_month_start",
6365
"is_month_end",
6466
"is_quarter_start",
@@ -74,7 +76,18 @@ def test_field_access(self, field) -> None:
7476
else:
7577
data = getattr(self.times, field)
7678

77-
expected = xr.DataArray(data, name=field, coords=[self.times], dims=["time"])
79+
if data.dtype.kind != "b" and field not in ("date", "time"):
80+
# pandas 2.0 returns int32 for integer fields now
81+
data = data.astype("int64")
82+
83+
translations = {
84+
"weekday": "dayofweek",
85+
"daysinmonth": "days_in_month",
86+
"weekofyear": "week",
87+
}
88+
name = translations.get(field, field)
89+
90+
expected = xr.DataArray(data, name=name, coords=[self.times], dims=["time"])
7891

7992
if field in ["week", "weekofyear"]:
8093
with pytest.warns(
@@ -84,7 +97,8 @@ def test_field_access(self, field) -> None:
8497
else:
8598
actual = getattr(self.data.time.dt, field)
8699

87-
assert_equal(expected, actual)
100+
assert expected.dtype == actual.dtype
101+
assert_identical(expected, actual)
88102

89103
@pytest.mark.parametrize(
90104
"field, pandas_field",

xarray/tests/test_dataarray.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,32 +1023,53 @@ def test_sel_dataarray_datetime_slice(self) -> None:
10231023
result = array.sel(delta=slice(array.delta[0], array.delta[-1]))
10241024
assert_equal(result, array)
10251025

1026-
def test_sel_float(self) -> None:
1026+
@pytest.mark.parametrize(
1027+
["coord_values", "indices"],
1028+
(
1029+
pytest.param(
1030+
np.array([0.0, 0.111, 0.222, 0.333], dtype="float64"),
1031+
slice(1, 3),
1032+
id="float64",
1033+
),
1034+
pytest.param(
1035+
np.array([0.0, 0.111, 0.222, 0.333], dtype="float32"),
1036+
slice(1, 3),
1037+
id="float32",
1038+
),
1039+
pytest.param(
1040+
np.array([0.0, 0.111, 0.222, 0.333], dtype="float32"), [2], id="scalar"
1041+
),
1042+
),
1043+
)
1044+
def test_sel_float(self, coord_values, indices) -> None:
10271045
data_values = np.arange(4)
10281046

1029-
# case coords are float32 and label is list of floats
1030-
float_values = [0.0, 0.111, 0.222, 0.333]
1031-
coord_values = np.asarray(float_values, dtype="float32")
1032-
array = DataArray(data_values, [("float32_coord", coord_values)])
1033-
expected = DataArray(data_values[1:3], [("float32_coord", coord_values[1:3])])
1034-
actual = array.sel(float32_coord=float_values[1:3])
1035-
# case coords are float16 and label is list of floats
1036-
coord_values_16 = np.asarray(float_values, dtype="float16")
1037-
expected_16 = DataArray(
1038-
data_values[1:3], [("float16_coord", coord_values_16[1:3])]
1039-
)
1040-
array_16 = DataArray(data_values, [("float16_coord", coord_values_16)])
1041-
actual_16 = array_16.sel(float16_coord=float_values[1:3])
1047+
arr = DataArray(data_values, coords={"x": coord_values}, dims="x")
10421048

1043-
# case coord, label are scalars
1044-
expected_scalar = DataArray(
1045-
data_values[2], coords={"float32_coord": coord_values[2]}
1049+
actual = arr.sel(x=coord_values[indices])
1050+
expected = DataArray(
1051+
data_values[indices], coords={"x": coord_values[indices]}, dims="x"
10461052
)
1047-
actual_scalar = array.sel(float32_coord=float_values[2])
10481053

1049-
assert_equal(expected, actual)
1050-
assert_equal(expected_scalar, actual_scalar)
1051-
assert_equal(expected_16, actual_16)
1054+
assert_equal(actual, expected)
1055+
1056+
def test_sel_float16(self) -> None:
1057+
data_values = np.arange(4)
1058+
coord_values = np.array([0.0, 0.111, 0.222, 0.333], dtype="float16")
1059+
indices = slice(1, 3)
1060+
1061+
message = "`pandas.Index` does not support the `float16` dtype.*"
1062+
1063+
with pytest.warns(DeprecationWarning, match=message):
1064+
arr = DataArray(data_values, coords={"x": coord_values}, dims="x")
1065+
with pytest.warns(DeprecationWarning, match=message):
1066+
expected = DataArray(
1067+
data_values[indices], coords={"x": coord_values[indices]}, dims="x"
1068+
)
1069+
1070+
actual = arr.sel(x=coord_values[indices])
1071+
1072+
assert_equal(actual, expected)
10521073

10531074
def test_sel_float_multiindex(self) -> None:
10541075
# regression test https://github.com/pydata/xarray/issues/5691

xarray/tests/test_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@ def new_method():
2323

2424

2525
@pytest.mark.parametrize(
26-
"a, b, expected", [["a", "b", np.array(["a", "b"])], [1, 2, pd.Index([1, 2])]]
26+
["a", "b", "expected"],
27+
[
28+
[np.array(["a"]), np.array(["b"]), np.array(["a", "b"])],
29+
[np.array([1], dtype="int64"), np.array([2], dtype="int64"), pd.Index([1, 2])],
30+
],
2731
)
2832
def test_maybe_coerce_to_str(a, b, expected):
29-
a = np.array([a])
30-
b = np.array([b])
3133
index = pd.Index(a).append(pd.Index(b))
3234

3335
actual = utils.maybe_coerce_to_str(index, [a, b])

0 commit comments

Comments
 (0)