Skip to content

Commit 63b0742

Browse files
ilan-gold, dcherian, pre-commit-ci[bot]
authored
fix: Filter out StringDType even when the backing array is not NumpyExtensionArray (#10559)
Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <[email protected]>
1 parent 85b5a16 commit 63b0742

File tree

12 files changed

+63
-36
lines changed

12 files changed

+63
-36
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ accel = [
3333
"numba>=0.59",
3434
"flox>=0.9",
3535
"opt_einsum",
36+
"numpy<2.3", # numba has not updated yet: https://github.com/numba/numba/issues/10105
3637
]
3738
complete = ["xarray[accel,etc,io,parallel,viz]"]
3839
io = [
@@ -324,6 +325,8 @@ known-first-party = ["xarray"]
324325
[tool.ruff.lint.flake8-tidy-imports]
325326
# Disallow all relative imports.
326327
ban-relative-imports = "all"
328+
[tool.ruff.lint.flake8-tidy-imports.banned-api]
329+
"pandas.api.types.is_extension_array_dtype".msg = "Use xarray.core.utils.is_allowed_extension_array{_dtype} instead. Only use the banned API if the incoming data has already been sanitized by xarray"
327330

328331
[tool.pytest.ini_options]
329332
addopts = [

xarray/computation/ops.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88
from __future__ import annotations
99

1010
import operator
11-
from typing import Literal
11+
from typing import TYPE_CHECKING, Literal
1212

1313
import numpy as np
1414

1515
from xarray.core import dtypes, duck_array_ops
1616

17+
if TYPE_CHECKING:
18+
pass
19+
1720
try:
1821
import bottleneck as bn
1922

@@ -158,8 +161,8 @@ def fillna(data, other, join="left", dataset_join="left"):
158161
)
159162

160163

161-
# Unsure why we get a mypy error here
162-
def where_method(self, cond, other=dtypes.NA): # type: ignore[has-type]
164+
# TODO: type this properly
165+
def where_method(self, cond, other=dtypes.NA): # type: ignore[unused-ignore,has-type]
163166
"""Return elements from `self` or `other` depending on `cond`.
164167
165168
Parameters

xarray/core/dataset.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
import numpy as np
2828
import pandas as pd
29-
from pandas.api.types import is_extension_array_dtype
3029

3130
from xarray.coding.calendar_ops import convert_calendar, interp_calendar
3231
from xarray.coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
@@ -91,6 +90,7 @@
9190
either_dict_or_kwargs,
9291
emit_user_level_warning,
9392
infix_dims,
93+
is_allowed_extension_array,
9494
is_dict_like,
9595
is_duck_array,
9696
is_duck_dask_array,
@@ -6780,7 +6780,7 @@ def reduce(
67806780
elif (
67816781
# Some reduction functions (e.g. std, var) need to run on variables
67826782
# that don't have the reduce dims: PR5393
6783-
not is_extension_array_dtype(var.dtype)
6783+
not pd.api.types.is_extension_array_dtype(var.dtype) # noqa: TID251
67846784
and (
67856785
not reduce_dims
67866786
or not numeric_only
@@ -7105,12 +7105,12 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
71057105
non_extension_array_columns = [
71067106
k
71077107
for k in columns_in_order
7108-
if not is_extension_array_dtype(self.variables[k].data)
7108+
if not pd.api.types.is_extension_array_dtype(self.variables[k].data) # noqa: TID251
71097109
]
71107110
extension_array_columns = [
71117111
k
71127112
for k in columns_in_order
7113-
if is_extension_array_dtype(self.variables[k].data)
7113+
if pd.api.types.is_extension_array_dtype(self.variables[k].data) # noqa: TID251
71147114
]
71157115
extension_array_columns_different_index = [
71167116
k
@@ -7302,7 +7302,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
73027302
arrays = []
73037303
extension_arrays = []
73047304
for k, v in dataframe.items():
7305-
if not is_extension_array_dtype(v) or isinstance(
7305+
if not is_allowed_extension_array(v) or isinstance(
73067306
v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES
73077307
):
73087308
arrays.append((k, np.asarray(v)))

xarray/core/dtypes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any
55

66
import numpy as np
7-
from pandas.api.types import is_extension_array_dtype
7+
import pandas as pd
88

99
from xarray.compat import array_api_compat, npcompat
1010
from xarray.compat.npcompat import HAS_STRING_DTYPE
@@ -213,7 +213,7 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
213213

214214
if isinstance(dtype, np.dtype):
215215
return npcompat.isdtype(dtype, kind)
216-
elif is_extension_array_dtype(dtype):
216+
elif pd.api.types.is_extension_array_dtype(dtype): # noqa: TID251
217217
# we never want to match pandas extension array dtypes
218218
return False
219219
else:

xarray/core/duck_array_ops.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
take,
2424
unravel_index, # noqa: F401
2525
)
26-
from pandas.api.types import is_extension_array_dtype
2726

2827
from xarray.compat import dask_array_compat, dask_array_ops
2928
from xarray.compat.array_api_compat import get_array_namespace
@@ -184,7 +183,7 @@ def isnull(data):
184183
dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool
185184
return full_like(data, dtype=dtype, fill_value=False)
186185
# at this point, array should have dtype=object
187-
elif isinstance(data, np.ndarray) or is_extension_array_dtype(data):
186+
elif isinstance(data, np.ndarray) or pd.api.types.is_extension_array_dtype(data): # noqa: TID251
188187
return pandas_isnull(data)
189188
else:
190189
# Not reachable yet, but intended for use with other duck array
@@ -266,10 +265,12 @@ def asarray(data, xp=np, dtype=None):
266265

267266
def as_shared_dtype(scalars_or_arrays, xp=None):
268267
"""Cast arrays to a shared dtype using xarray's type promotion rules."""
269-
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
270-
extension_array_types = [
271-
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
272-
]
268+
extension_array_types = [
269+
x.dtype
270+
for x in scalars_or_arrays
271+
if pd.api.types.is_extension_array_dtype(x) # noqa: TID251
272+
]
273+
if len(extension_array_types) >= 1:
273274
non_nans = [x for x in scalars_or_arrays if not isna(x)]
274275
if len(extension_array_types) == len(non_nans) and all(
275276
isinstance(x, type(extension_array_types[0])) for x in extension_array_types

xarray/core/extension_array.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
import numpy as np
99
import pandas as pd
1010
from packaging.version import Version
11-
from pandas.api.types import is_extension_array_dtype
1211

1312
from xarray.core.types import DTypeLikeSave, T_ExtensionArray
14-
from xarray.core.utils import NDArrayMixin
13+
from xarray.core.utils import NDArrayMixin, is_allowed_extension_array
1514

1615
HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}
1716

@@ -100,10 +99,11 @@ def __post_init__(self):
10099
raise TypeError(f"{self.array} is not an pandas ExtensionArray.")
101100
# This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because
102101
# we do support extension arrays from datetime, for example, that need
103-
# duck array support internally via this class.
104-
if isinstance(self.array, pd.arrays.NumpyExtensionArray):
102+
# duck array support internally via this class. These can appear from `DatetimeIndex`
103+
# wrapped by `PandasIndex` internally, for example.
104+
if not is_allowed_extension_array(self.array):
105105
raise TypeError(
106-
"`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally."
106+
f"{self.array.dtype!r} should be converted to a numpy array in `xarray` internally."
107107
)
108108

109109
def __array_function__(self, func, types, args, kwargs):
@@ -126,7 +126,7 @@ def replace_duck_with_extension_array(args) -> list:
126126
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
127127
raise KeyError("Function not registered for pandas extension arrays.")
128128
res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs)
129-
if is_extension_array_dtype(res):
129+
if is_allowed_extension_array(res):
130130
return PandasExtensionArray(res)
131131
return res
132132

@@ -135,7 +135,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
135135

136136
def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
137137
item = self.array[key]
138-
if is_extension_array_dtype(item):
138+
if is_allowed_extension_array(item):
139139
return PandasExtensionArray(item)
140140
if np.isscalar(item) or isinstance(key, int):
141141
return PandasExtensionArray(type(self.array)._from_sequence([item])) # type: ignore[call-arg,attr-defined,unused-ignore]

xarray/core/indexes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Frozen,
2424
emit_user_level_warning,
2525
get_valid_numpy_dtype,
26+
is_allowed_extension_array_dtype,
2627
is_dict_like,
2728
is_scalar,
2829
)
@@ -666,9 +667,8 @@ def __init__(
666667

667668
self.index = index
668669
self.dim = dim
669-
670670
if coord_dtype is None:
671-
if pd.api.types.is_extension_array_dtype(index.dtype):
671+
if is_allowed_extension_array_dtype(index.dtype):
672672
cast(pd.api.extensions.ExtensionDtype, index.dtype)
673673
coord_dtype = index.dtype
674674
else:

xarray/core/indexing.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
NDArrayMixin,
2525
either_dict_or_kwargs,
2626
get_valid_numpy_dtype,
27+
is_allowed_extension_array,
28+
is_allowed_extension_array_dtype,
2729
is_duck_array,
2830
is_duck_dask_array,
2931
is_scalar,
@@ -1763,12 +1765,12 @@ def __init__(
17631765
self.array = safe_cast_to_index(array)
17641766

17651767
if dtype is None:
1766-
if pd.api.types.is_extension_array_dtype(array.dtype):
1768+
if is_allowed_extension_array(array):
17671769
cast(pd.api.extensions.ExtensionDtype, array.dtype)
17681770
self._dtype = array.dtype
17691771
else:
17701772
self._dtype = get_valid_numpy_dtype(array)
1771-
elif pd.api.types.is_extension_array_dtype(dtype):
1773+
elif is_allowed_extension_array_dtype(dtype):
17721774
self._dtype = cast(pd.api.extensions.ExtensionDtype, dtype)
17731775
else:
17741776
self._dtype = np.dtype(cast(DTypeLike, dtype))
@@ -1816,10 +1818,7 @@ def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
18161818
# We return a PandasExtensionArray wrapper type that satisfies
18171819
# duck array protocols.
18181820
# `NumpyExtensionArray` is excluded
1819-
if pd.api.types.is_extension_array_dtype(self.array) and not isinstance(
1820-
self.array.array,
1821-
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
1822-
):
1821+
if is_allowed_extension_array(self.array):
18231822
from xarray.core.extension_array import PandasExtensionArray
18241823

18251824
return PandasExtensionArray(self.array.array)
@@ -1916,7 +1915,7 @@ def copy(self, deep: bool = True) -> Self:
19161915

19171916
@property
19181917
def nbytes(self) -> int:
1919-
if pd.api.types.is_extension_array_dtype(self.dtype):
1918+
if is_allowed_extension_array(self.array):
19201919
return self.array.nbytes
19211920

19221921
dtype = self._get_numpy_dtype()

xarray/core/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,20 @@
104104
T = TypeVar("T")
105105

106106

107+
def is_allowed_extension_array_dtype(dtype: Any) -> bool:
    """Return whether ``dtype`` is a pandas extension dtype other than ``StringDtype``.

    Parameters
    ----------
    dtype : Any
        A dtype object. Objects exposing a ``.dtype`` attribute (for example a
        ``pandas.Series``) are also accepted, in which case that attribute is
        inspected, mirroring ``pandas.api.types.is_extension_array_dtype``.

    Returns
    -------
    bool
        ``True`` if pandas reports an extension dtype and that dtype is not an
        instance of ``pandas.StringDtype``; ``False`` otherwise.
    """
    # pandas resolves array-likes to their ``.dtype`` before checking; do the
    # same here so the ``StringDtype`` exclusion is applied to the dtype that
    # pandas actually examined rather than to the container object.
    resolved = getattr(dtype, "dtype", dtype)
    return bool(
        pd.api.types.is_extension_array_dtype(resolved)  # noqa: TID251
        and not isinstance(resolved, pd.StringDtype)
    )
111+
112+
113+
def is_allowed_extension_array(array: Any) -> bool:
    """Return whether ``array`` has an allowed extension dtype and is not a ``NumpyExtensionArray``.

    Parameters
    ----------
    array : Any
        Candidate object; anything without a ``dtype`` attribute is rejected.

    Returns
    -------
    bool
        ``True`` only when ``array`` has a ``dtype``, that dtype passes
        ``is_allowed_extension_array_dtype``, and ``array`` is not an instance
        of ``pandas.arrays.NumpyExtensionArray``.
    """
    # Objects without a dtype cannot be extension arrays at all.
    if not hasattr(array, "dtype"):
        return False
    # Delegate the dtype-level filtering to the dtype helper.
    if not is_allowed_extension_array_dtype(array.dtype):
        return False
    # Finally, exclude pandas' NumpyExtensionArray container type.
    return not isinstance(array, pd.arrays.NumpyExtensionArray)  # type: ignore[attr-defined]
119+
120+
107121
def alias_message(old_name: str, new_name: str) -> str:
108122
return f"{old_name} has been deprecated. Use {new_name} instead."
109123

xarray/core/variable.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
emit_user_level_warning,
4141
ensure_us_time_resolution,
4242
infix_dims,
43+
is_allowed_extension_array,
4344
is_dict_like,
4445
is_duck_array,
4546
is_duck_dask_array,
@@ -198,7 +199,9 @@ def _maybe_wrap_data(data):
198199
return PandasIndexingAdapter(data)
199200
if isinstance(data, UNSUPPORTED_EXTENSION_ARRAY_TYPES):
200201
return data.to_numpy()
201-
if isinstance(data, pd.api.extensions.ExtensionArray):
202+
if isinstance(
203+
data, pd.api.extensions.ExtensionArray
204+
) and is_allowed_extension_array(data):
202205
return PandasExtensionArray(data)
203206
return data
204207

@@ -261,7 +264,8 @@ def convert_non_numpy_type(data):
261264
if isinstance(data, pd.Series | pd.DataFrame):
262265
if (
263266
isinstance(data, pd.Series)
264-
and pd.api.types.is_extension_array_dtype(data)
267+
and is_allowed_extension_array(data.array)
268+
# Some datetime types are also not allowed as backing Variable types
265269
and not isinstance(data.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES)
266270
):
267271
pandas_data = data.array

0 commit comments

Comments
 (0)