(feat): Support for pandas ExtensionArray #8723
Changes from 40 commits
xarray/core/dataset.py
@@ -24,6 +24,7 @@
 from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload
 
 import numpy as np
+from pandas.api.types import is_extension_array_dtype
 
 # remove once numpy 2.0 is the oldest supported version
 try:
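For reference, `is_extension_array_dtype` is the public pandas helper used throughout this PR to tell extension-backed data apart from plain NumPy data:

```python
import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype

print(is_extension_array_dtype(pd.Series(["a", "b"], dtype="category")))  # True
print(is_extension_array_dtype(pd.array([1, None, 3], dtype="Int64")))    # True
print(is_extension_array_dtype(np.array([1, 2, 3])))                      # False
```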
@@ -6835,6 +6836,7 @@ def reduce(
                     # that don't have the reduce dims: PR5393
                     not reduce_dims
                     or not numeric_only
+                    or not is_extension_array_dtype(var.dtype)
                     or np.issubdtype(var.dtype, np.number)
                     or (var.dtype == np.bool_)
                 ):
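This extra clause feeds the decision of which variables a `numeric_only` reduction keeps. A rough illustration of the situation it has to handle, assuming the rest of this PR so that an extension-dtype variable can exist in a Dataset at all (the exact keep/drop outcome is determined by the full condition above):

```python
import pandas as pd
import xarray as xr

ds = xr.Dataset(
    {
        "value": ("index", [1.0, 2.0, 3.0]),
        # Extension-dtype variable; storing it unconverted assumes this PR.
        "cat": ("index", pd.Categorical(["a", "b", "a"])),
    }
)

# mean() is a numeric_only reduction: "value" can be averaged, while "cat"
# cannot, and np.issubdtype must not be asked about a pandas extension dtype.
print(ds.mean())
```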
@@ -7149,13 +7151,33 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
             )
 
     def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
-        columns = [k for k in self.variables if k not in self.dims]
+        columns = [
+            k
+            for k in self.variables
+            if k not in self.dims
+            and not is_extension_array_dtype(self.variables[k].data)
+        ]
+        extension_array_columns = [
+            k
+            for k in self.variables
+            if k not in self.dims and is_extension_array_dtype(self.variables[k].data)
+        ]
         data = [
             self._variables[k].set_dims(ordered_dims).values.reshape(-1)
             for k in columns
         ]
         index = self.coords.to_index([*ordered_dims])
-        return pd.DataFrame(dict(zip(columns, data)), index=index)
+        broadcasted_df = pd.DataFrame(dict(zip(columns, data)), index=index)
+        for extension_array_column in extension_array_columns:
+            extension_array = self.variables[extension_array_column].data.array
+            index = self[self.variables[extension_array_column].dims[0]].data
+            cat_df = pd.DataFrame(
+                {extension_array_column: extension_array},
+                index=self[self.variables[extension_array_column].dims[0]].data,
+            )
+            cat_df.index.name = self.variables[extension_array_column].dims[0]
+            broadcasted_df = broadcasted_df.join(cat_df)
+        return broadcasted_df
 
     def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame:
         """Convert this dataset into a pandas.DataFrame.
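In short: regular variables are broadcast and packed into the frame as before, while each extension-array variable is built into its own single-column frame indexed by its one dimension and then joined back on, so the extension dtype is preserved. A hedged usage sketch, assuming the feature as merged:

```python
import pandas as pd
import xarray as xr

ds = xr.Dataset(
    {
        "x": ("index", [10.0, 20.0, 30.0]),
        # Assumes this PR: the categorical is stored without NumPy coercion.
        "cat": ("index", pd.Categorical(["a", "b", "a"])),
    },
    coords={"index": [0, 1, 2]},
)

df = ds.to_dataframe()
print(df.dtypes)  # expected: x -> float64, cat -> category
```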
@@ -7301,11 +7323,14 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
                 "cannot convert a DataFrame with a non-unique MultiIndex into xarray"
             )
 
-        # Cast to a NumPy array first, in case the Series is a pandas Extension
-        # array (which doesn't have a valid NumPy dtype)
-        # TODO: allow users to control how this casting happens, e.g., by
-        # forwarding arguments to pandas.Series.to_numpy?
-        arrays = [(k, np.asarray(v)) for k, v in dataframe.items()]
+        arrays = [
+            (k, np.asarray(v))
+            for k, v in dataframe.items()
+            if not is_extension_array_dtype(v)
+        ]
+        extension_arrays = [
+            (k, v) for k, v in dataframe.items() if is_extension_array_dtype(v)
+        ]
 
         indexes: dict[Hashable, Index] = {}
         index_vars: dict[Hashable, Variable] = {}
@@ -7319,6 +7344,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
                 xr_idx = PandasIndex(lev, dim)
                 indexes[dim] = xr_idx
                 index_vars.update(xr_idx.create_variables())
+            arrays += [(k, np.asarray(v)) for k, v in extension_arrays]
+            extension_arrays = []
         else:
             index_name = idx.name if idx.name is not None else "index"
             dims = (index_name,)
@@ -7332,6 +7359,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
             obj._set_sparse_data_from_dataframe(idx, arrays, dims)
         else:
             obj._set_numpy_data_from_dataframe(idx, arrays, dims)
+        for name, extension_array in extension_arrays:
+            obj[name] = (dims, extension_array)
         return obj
 
     def to_dask_dataframe(
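Taken together, `from_dataframe` now splits columns into NumPy-convertible ones and extension-array ones. Extension columns are only pushed through `np.asarray` when a MultiIndex forces broadcasting; otherwise they are assigned to the Dataset as 1-D variables unchanged. A hedged sketch of both paths:

```python
import pandas as pd
import xarray as xr

# Flat index: the categorical column is expected to keep its extension dtype.
df = pd.DataFrame(
    {"cat": pd.Categorical(["a", "b", "b"]), "x": [1, 2, 3]},
    index=pd.Index([0, 1, 2], name="index"),
)
print(xr.Dataset.from_dataframe(df)["cat"].dtype)  # expected: category

# MultiIndex: values must be broadcast/unstacked onto the full grid, so
# extension columns are expected to fall back to np.asarray (NumPy dtype).
midx = pd.MultiIndex.from_product([[0, 1], ["x", "y"]], names=["a", "b"])
df2 = pd.DataFrame({"cat": pd.Categorical(["p", "q", "p", "q"])}, index=midx)
print(xr.Dataset.from_dataframe(df2)["cat"].dtype)  # expected: a NumPy dtype
```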
xarray/core/duck_array_ops.py
@@ -10,8 +10,10 @@
 import datetime
 import inspect
 import warnings
+from collections.abc import Sequence
 from functools import partial
 from importlib import import_module
+from typing import Callable
 
 import numpy as np
 import pandas as pd
@@ -32,11 +34,21 @@
 from numpy import concatenate as _concatenate
 from numpy.lib.stride_tricks import sliding_window_view  # noqa
 from packaging.version import Version
+from pandas.api.types import is_extension_array_dtype
 
+try:
+    from plum import dispatch  # type: ignore[import-not-found]
+except ImportError:
+
+    def dispatch(func):
+        return func
+
+
 from xarray.core import dask_array_ops, dtypes, nputils, pycompat
 from xarray.core.options import OPTIONS
 from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
 from xarray.core.pycompat import array_type, is_duck_dask_array
+from xarray.core.types import DTypeLikeSave, T_ExtensionArray
 from xarray.core.utils import is_duck_array, module_available
 
 # remove once numpy 2.0 is the oldest supported version
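plum provides runtime multiple dispatch on type annotations, which is what lets the `__extension_duck_array__*` functions below be overloaded per extension-array type; when plum is not installed, the fallback turns `@dispatch` into a no-op decorator. A small standalone illustration of the pattern (assumes the optional `plum-dispatch` package is installed; the function names here are made up):

```python
import pandas as pd
from plum import dispatch


@dispatch
def describe(arr: pd.Categorical) -> str:
    # Chosen when the argument is a Categorical.
    return f"categorical with {len(arr.categories)} categories"


@dispatch
def describe(arr: pd.arrays.IntegerArray) -> str:  # noqa: F811
    # Chosen for nullable integer arrays.
    return "nullable integer array"


print(describe(pd.Categorical(["a", "b", "a"])))        # categorical with 2 categories
print(describe(pd.array([1, None, 3], dtype="Int64")))  # nullable integer array
```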
@@ -53,6 +65,64 @@
 dask_available = module_available("dask")
 
 
+HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}
+
+
+def implements(numpy_function):
+    """Register an __array_function__ implementation for MyArray objects."""
+
+    def decorator(func):
+        HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func
+        return func
+
+    return decorator
+
+
+@implements(np.issubdtype)
+@dispatch
+def __extension_duck_array__issubdtype(
+    extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave
+) -> bool:
+    return False  # never want a function to think a pandas extension dtype is a subtype of numpy
+
+
+@implements(np.broadcast_to)
+@dispatch
+def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple):
+    if shape[0] == len(arr) and len(shape) == 1:
+        return arr
+    raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.")
+
+
+@implements(np.stack)
+@dispatch
+def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int):
+    raise NotImplementedError("Cannot stack 1d-only pandas categorical array.")
+
+
+@implements(np.concatenate)
+@dispatch
+def __extension_duck_array__concatenate(
+    arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None
+) -> T_ExtensionArray:
+    return type(arrays[0])._concat_same_type(arrays)
+
+
+@implements(np.where)
+@dispatch
+def __extension_duck_array__where(
+    condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray
+) -> T_ExtensionArray:
+    if (
+        isinstance(x, pd.Categorical)
+        and isinstance(y, pd.Categorical)
+        and x.dtype != y.dtype
+    ):
+        x = x.add_categories(set(y.categories).difference(set(x.categories)))
+        y = y.add_categories(set(x.categories).difference(set(y.categories)))
+    return pd.Series(x).where(condition, pd.Series(y)).array
+
+
 def get_array_namespace(x):
     if hasattr(x, "__array_namespace__"):
         return x.__array_namespace__()
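`HANDLED_EXTENSION_ARRAY_FUNCTIONS` maps NumPy functions to extension-array-aware replacements; the wrapper that actually consumes the registry (via the `__array_function__` protocol) lives elsewhere in this PR and is not shown in this hunk. A simplified, self-contained sketch of how such a registry-driven dispatch works, with hypothetical names:

```python
import numpy as np
import pandas as pd

HANDLED: dict = {}


def implements(numpy_function):
    # Same registration pattern as HANDLED_EXTENSION_ARRAY_FUNCTIONS above.
    def decorator(func):
        HANDLED[numpy_function] = func
        return func

    return decorator


@implements(np.concatenate)
def _concat(arrays, axis=0, out=None):
    # Concatenate extension arrays without a detour through NumPy object arrays.
    return type(arrays[0])._concat_same_type(arrays)


def call(numpy_function, *args, **kwargs):
    # Hypothetical dispatcher: use the registered handler if one exists,
    # otherwise fall back to the plain NumPy function.
    handler = HANDLED.get(numpy_function, numpy_function)
    return handler(*args, **kwargs)


a = pd.Categorical(["a", "b"])
b = pd.Categorical(["b", "a"])
print(call(np.concatenate, [a, b]))  # result is still a Categorical
```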
@@ -155,7 +225,7 @@ def isnull(data):
         return full_like(data, dtype=bool, fill_value=False)
     else:
         # at this point, array should have dtype=object
-        if isinstance(data, np.ndarray):
+        if isinstance(data, np.ndarray) or is_extension_array_dtype(data):
             return pandas_isnull(data)
         else:
             # Not reachable yet, but intended for use with other duck array
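`pandas_isnull` (pandas' own isnull/isna) already understands extension arrays, so routing them into the same branch as NumPy object arrays yields the expected boolean mask:

```python
import pandas as pd

print(pd.isnull(pd.Categorical(["a", None, "b"])))       # [False  True False]
print(pd.isnull(pd.array([1, None, 3], dtype="Int64")))  # [False  True False]
```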
@@ -220,9 +290,17 @@
 
 def as_shared_dtype(scalars_or_arrays, xp=np):
     """Cast a arrays to a shared dtype using xarray's type promotion rules."""
-    array_type_cupy = array_type("cupy")
-    if array_type_cupy and any(
-        isinstance(x, array_type_cupy) for x in scalars_or_arrays
+    if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
+        extension_array_types = [
+            x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
+        ]
+        if len(extension_array_types) == len(scalars_or_arrays) and all(
+            isinstance(x, type(extension_array_types[0])) for x in extension_array_types
+        ):
+            return scalars_or_arrays
+        arrays = [asarray(np.array(x), xp=xp) for x in scalars_or_arrays]
+    elif array_type_cupy := array_type("cupy") and any(  # noqa: F841
+        isinstance(x, array_type_cupy) for x in scalars_or_arrays  # noqa: F821
     ):
         import cupy as cp
 
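The new first branch leaves the inputs alone only when every element is an extension array and all of them share the same dtype class; any mix with scalars or NumPy arrays falls back to `np.array` conversion before the usual promotion. A standalone sketch of just that check, using hypothetical helper names:

```python
import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype


def shared_dtype_passthrough(values):
    # Mirrors the check above: if every input is an extension array and all of
    # them share one dtype class, keep them as-is; otherwise cast to NumPy.
    ext_dtypes = [v.dtype for v in values if is_extension_array_dtype(v)]
    if len(ext_dtypes) == len(values) and all(
        isinstance(d, type(ext_dtypes[0])) for d in ext_dtypes
    ):
        return values
    return [np.array(v) for v in values]


cats = [pd.Categorical(["a", "b"]), pd.Categorical(["b", "c"])]
print(type(shared_dtype_passthrough(cats)[0]))  # Categorical, passed through

mixed = [pd.Categorical(["a", "b"]), np.array([1.0, 2.0])]
print(shared_dtype_passthrough(mixed)[0].dtype)  # coerced to a NumPy dtype
```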