Skip to content

Commit 93a4c03

Browse files
ilan-goldshoyerdcherianspencerkclarkkmuehlbauer
authored
(fix): remove __getattr__ method for PandasExtensionArray (pydata#10250)
Co-authored-by: Stephan Hoyer <[email protected]> Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: Spencer Clark <[email protected]> Co-authored-by: Kai Mühlbauer <[email protected]> Co-authored-by: Benoit Bovy <[email protected]> Co-authored-by: Kai Muehlbauer <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <[email protected]>
1 parent b741d8c commit 93a4c03

File tree

11 files changed

+73
-47
lines changed

11 files changed

+73
-47
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ New Features
2929
- Improved compatibility with OPeNDAP DAP4 data model for backend engine ``pydap``. This
3030
includes ``datatree`` support, and removing slashes from dimension names. By
3131
`Miguel Jimenez-Urias <https://github.com/Mikejmnez>`_.
32-
- Improved support pandas Extension Arrays. (:issue:`9661`, :pull:`9671`)
32+
- Improved support for pandas categorical extension as indices (i.e., :py:class:`pandas.IntervalIndex`). (:issue:`9661`, :pull:`9671`)
3333
By `Ilan Gold <https://github.com/ilan-gold>`_.
3434
- Improved checks and errors raised when trying to align objects with conflicting indexes.
3535
It is now possible to align objects each with multiple indexes sharing common dimension(s).
@@ -52,6 +52,7 @@ Breaking changes
5252
now return objects indexed by :py:meth:`pandas.IntervalArray` objects,
5353
instead of numpy object arrays containing tuples. This change enables interval-aware indexing of
5454
such Xarray objects. (:pull:`9671`). By `Ilan Gold <https://github.com/ilan-gold>`_.
55+
- Remove ``PandasExtensionArray`` from :py:attr:`xarray.Variable.data` when the attribute is a :py:class:`pandas.api.extensions.ExtensionArray` (:pull:`10263`). By `Ilan Gold <https://github.com/ilan-gold>`_.
5556
- The html and text ``repr`` for ``DataTree`` are now truncated. Up to 6 children are displayed
5657
for each node -- the first 3 and the last 3 children -- with a ``...`` between them. The number
5758
of children to include in the display is configurable via options. For instance use

xarray/core/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7091,21 +7091,21 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
70917091
{
70927092
**dict(zip(non_extension_array_columns, data, strict=True)),
70937093
**{
7094-
c: self.variables[c].data.array
7094+
c: self.variables[c].data
70957095
for c in extension_array_columns_same_index
70967096
},
70977097
},
70987098
index=index,
70997099
)
71007100
for extension_array_column in extension_array_columns_different_index:
7101-
extension_array = self.variables[extension_array_column].data.array
7101+
extension_array = self.variables[extension_array_column].data
71027102
index = self[
71037103
self.variables[extension_array_column].dims[0]
71047104
].coords.to_index()
71057105
extension_array_df = pd.DataFrame(
71067106
{extension_array_column: extension_array},
71077107
index=pd.Index(index.array)
7108-
if isinstance(index, PandasExtensionArray)
7108+
if isinstance(index, PandasExtensionArray) # type: ignore[redundant-expr]
71097109
else index,
71107110
)
71117111
extension_array_df.index.name = self.variables[extension_array_column].dims[

xarray/core/extension_array.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable, Sequence
4+
from dataclasses import dataclass
45
from typing import Generic, cast
56

67
import numpy as np
78
import pandas as pd
9+
from packaging.version import Version
810
from pandas.api.types import is_extension_array_dtype
911

1012
from xarray.core.types import DTypeLikeSave, T_ExtensionArray
13+
from xarray.core.utils import NDArrayMixin
1114

1215
HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}
1316

@@ -33,12 +36,12 @@ def __extension_duck_array__issubdtype(
3336
def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple):
3437
if shape[0] == len(arr) and len(shape) == 1:
3538
return arr
36-
raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.")
39+
raise NotImplementedError("Cannot broadcast 1d-only pandas extension array.")
3740

3841

3942
@implements(np.stack)
4043
def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int):
41-
raise NotImplementedError("Cannot stack 1d-only pandas categorical array.")
44+
raise NotImplementedError("Cannot stack 1d-only pandas extension array.")
4245

4346

4447
@implements(np.concatenate)
@@ -62,21 +65,22 @@ def __extension_duck_array__where(
6265
return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array)
6366

6467

65-
class PandasExtensionArray(Generic[T_ExtensionArray]):
66-
array: T_ExtensionArray
68+
@dataclass(frozen=True)
69+
class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin):
70+
"""NEP-18 compliant wrapper for pandas extension arrays.
71+
72+
Parameters
73+
----------
74+
array : T_ExtensionArray
75+
The array to be wrapped upon, e.g., :py:class:`xarray.Variable` creation.
76+
```
77+
"""
6778

68-
def __init__(self, array: T_ExtensionArray):
69-
"""NEP-18 compliant wrapper for pandas extension arrays.
79+
array: T_ExtensionArray
7080

71-
Parameters
72-
----------
73-
array : T_ExtensionArray
74-
The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation.
75-
```
76-
"""
77-
if not isinstance(array, pd.api.extensions.ExtensionArray):
78-
raise TypeError(f"{array} is not an pandas ExtensionArray.")
79-
self.array = array
81+
def __post_init__(self):
82+
if not isinstance(self.array, pd.api.extensions.ExtensionArray):
83+
raise TypeError(f"{self.array} is not an pandas ExtensionArray.")
8084

8185
def __array_function__(self, func, types, args, kwargs):
8286
def replace_duck_with_extension_array(args) -> list:
@@ -105,19 +109,13 @@ def replace_duck_with_extension_array(args) -> list:
105109
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
106110
return ufunc(*inputs, **kwargs)
107111

108-
def __repr__(self):
109-
return f"PandasExtensionArray(array={self.array!r})"
110-
111-
def __getattr__(self, attr: str) -> object:
112-
return getattr(self.array, attr)
113-
114112
def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
115113
item = self.array[key]
116114
if is_extension_array_dtype(item):
117-
return type(self)(item)
118-
if np.isscalar(item):
119-
return type(self)(type(self.array)([item])) # type: ignore[call-arg] # only subclasses with proper __init__ allowed
120-
return item
115+
return PandasExtensionArray(item)
116+
if np.isscalar(item) or isinstance(key, int):
117+
return PandasExtensionArray(type(self.array)._from_sequence([item])) # type: ignore[call-arg,attr-defined,unused-ignore]
118+
return PandasExtensionArray(item)
121119

122120
def __setitem__(self, key, val):
123121
self.array[key] = val
@@ -132,3 +130,15 @@ def __ne__(self, other):
132130

133131
def __len__(self):
134132
return len(self.array)
133+
134+
@property
135+
def ndim(self) -> int:
136+
return 1
137+
138+
def __array__(
139+
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
140+
) -> np.ndarray:
141+
if Version(np.__version__) >= Version("2.0.0"):
142+
return np.asarray(self.array, dtype=dtype, copy=copy)
143+
else:
144+
return np.asarray(self.array, dtype=dtype)

xarray/core/formatting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,8 @@ def short_array_repr(array):
626626

627627
if isinstance(array, AbstractArray):
628628
array = array.data
629+
if isinstance(array, pd.api.extensions.ExtensionArray):
630+
return repr(array)
629631
array = to_duck_array(array)
630632

631633
# default to lower precision so a full (abbreviated) line can fit on

xarray/core/indexing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
from xarray.core import duck_array_ops
2121
from xarray.core.coordinate_transform import CoordinateTransform
22-
from xarray.core.extension_array import PandasExtensionArray
2322
from xarray.core.nputils import NumpyVIndexAdapter
2423
from xarray.core.options import OPTIONS
2524
from xarray.core.types import T_Xarray
@@ -37,6 +36,7 @@
3736
from xarray.namedarray.pycompat import array_type, integer_types, is_chunked_array
3837

3938
if TYPE_CHECKING:
39+
from xarray.core.extension_array import PandasExtensionArray
4040
from xarray.core.indexes import Index
4141
from xarray.core.types import Self
4242
from xarray.core.variable import Variable
@@ -1797,6 +1797,8 @@ def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
17971797
# We return a PandasExtensionArray wrapper type that satisfies
17981798
# duck array protocols. This is what's needed for tests to pass.
17991799
if pd.api.types.is_extension_array_dtype(self.array):
1800+
from xarray.core.extension_array import PandasExtensionArray
1801+
18001802
return PandasExtensionArray(self.array.array)
18011803
return np.asarray(self)
18021804

xarray/core/variable.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -410,12 +410,20 @@ def data(self):
410410
Variable.as_numpy
411411
Variable.values
412412
"""
413-
if is_duck_array(self._data):
414-
return self._data
413+
if isinstance(self._data, PandasExtensionArray):
414+
duck_array = self._data.array
415415
elif isinstance(self._data, indexing.ExplicitlyIndexed):
416-
return self._data.get_duck_array()
416+
duck_array = self._data.get_duck_array()
417+
elif is_duck_array(self._data):
418+
duck_array = self._data
417419
else:
418-
return self.values
420+
duck_array = self.values
421+
if isinstance(duck_array, PandasExtensionArray):
422+
# even though PandasExtensionArray is a duck array,
423+
# we should not return the PandasExtensionArray wrapper,
424+
# and instead return the underlying data.
425+
return duck_array.array
426+
return duck_array
419427

420428
@data.setter
421429
def data(self, data: T_DuckArray | ArrayLike) -> None:
@@ -1366,7 +1374,7 @@ def set_dims(self, dim, shape=None):
13661374
elif shape is not None:
13671375
dims_map = dict(zip(dim, shape, strict=True))
13681376
tmp_shape = tuple(dims_map[d] for d in expanded_dims)
1369-
expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape)
1377+
expanded_data = duck_array_ops.broadcast_to(self._data, tmp_shape)
13701378
else:
13711379
indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,)
13721380
expanded_data = self.data[indexer]

xarray/tests/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def assert_writeable(ds):
6060
name
6161
for name, var in ds.variables.items()
6262
if not isinstance(var, IndexVariable)
63-
and not isinstance(var.data, PandasExtensionArray)
63+
and not isinstance(
64+
var.data, PandasExtensionArray | pd.api.extensions.ExtensionArray
65+
)
6466
and not var.data.flags.writeable
6567
]
6668
assert not readonly, readonly

xarray/tests/test_concat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,10 @@ def test_concat_categorical() -> None:
160160
concatenated = concat([data1, data2], dim="dim1")
161161
assert (
162162
concatenated["var4"]
163-
== type(data2["var4"].variable.data.array)._concat_same_type(
163+
== type(data2["var4"].variable.data)._concat_same_type(
164164
[
165-
data1["var4"].variable.data.array,
166-
data2["var4"].variable.data.array,
165+
data1["var4"].variable.data,
166+
data2["var4"].variable.data,
167167
]
168168
)
169169
).all()

xarray/tests/test_dataset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1826,6 +1826,12 @@ def test_categorical_reindex(self) -> None:
18261826
actual = ds.reindex(cat=["foo"])["cat"].values
18271827
assert (actual == np.array(["foo"])).all()
18281828

1829+
def test_extension_array_reindex_same(self) -> None:
1830+
series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype())
1831+
test = xr.Dataset({"test": series})
1832+
res = test.reindex(dim_0=series.index)
1833+
align(res, test, join="exact")
1834+
18291835
def test_categorical_multiindex(self) -> None:
18301836
i1 = pd.Series([0, 0])
18311837
cat = pd.CategoricalDtype(categories=["foo", "baz", "bar"])

xarray/tests/test_duck_array_ops.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ def test_extension_array_pyarrow_concatenate(self, arrow1, arrow2):
196196
concatenated = concatenate(
197197
(PandasExtensionArray(arrow1), PandasExtensionArray(arrow2))
198198
)
199-
assert concatenated[2]["x"] == 3
200-
assert concatenated[3]["y"]
199+
assert concatenated[2].array[0]["x"] == 3
200+
assert concatenated[3].array[0]["y"]
201201

202202
def test___getitem__extension_duck_array(self, categorical1):
203203
extension_duck_array = PandasExtensionArray(categorical1)
@@ -1094,8 +1094,3 @@ def test_extension_array_singleton_equality(categorical1):
10941094
def test_extension_array_repr(int1):
10951095
int_duck_array = PandasExtensionArray(int1)
10961096
assert repr(int1) in repr(int_duck_array)
1097-
1098-
1099-
def test_extension_array_attr(int1):
1100-
int_duck_array = PandasExtensionArray(int1)
1101-
assert (~int_duck_array.fillna(10)).all()

0 commit comments

Comments
 (0)