8
8
import numpy as np
9
9
import pandas as pd
10
10
from packaging .version import Version
11
- from pandas .api .types import is_extension_array_dtype
12
11
13
12
from xarray .core .types import DTypeLikeSave , T_ExtensionArray
14
- from xarray .core .utils import NDArrayMixin
13
+ from xarray .core .utils import NDArrayMixin , is_allowed_extension_array
15
14
16
15
HANDLED_EXTENSION_ARRAY_FUNCTIONS : dict [Callable , Callable ] = {}
17
16
@@ -100,10 +99,11 @@ def __post_init__(self):
100
99
raise TypeError (f"{ self .array } is not an pandas ExtensionArray." )
101
100
# This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because
102
101
# we do support extension arrays from datetime, for example, that need
103
- # duck array support internally via this class.
104
- if isinstance (self .array , pd .arrays .NumpyExtensionArray ):
102
+ # duck array support internally via this class. These can appear from `DatetimeIndex`
103
+ # wrapped by `PandasIndex` internally, for example.
104
+ if not is_allowed_extension_array (self .array ):
105
105
raise TypeError (
106
- "`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally."
106
+ f" { self . array . dtype !r } should be converted to a numpy array in `xarray` internally."
107
107
)
108
108
109
109
def __array_function__ (self , func , types , args , kwargs ):
@@ -126,7 +126,7 @@ def replace_duck_with_extension_array(args) -> list:
126
126
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS :
127
127
raise KeyError ("Function not registered for pandas extension arrays." )
128
128
res = HANDLED_EXTENSION_ARRAY_FUNCTIONS [func ](* args , ** kwargs )
129
- if is_extension_array_dtype (res ):
129
+ if is_allowed_extension_array (res ):
130
130
return PandasExtensionArray (res )
131
131
return res
132
132
@@ -135,7 +135,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
135
135
136
136
def __getitem__ (self , key ) -> PandasExtensionArray [T_ExtensionArray ]:
137
137
item = self .array [key ]
138
- if is_extension_array_dtype (item ):
138
+ if is_allowed_extension_array (item ):
139
139
return PandasExtensionArray (item )
140
140
if np .isscalar (item ) or isinstance (key , int ):
141
141
return PandasExtensionArray (type (self .array )._from_sequence ([item ])) # type: ignore[call-arg,attr-defined,unused-ignore]
0 commit comments