|
13 | 13 | from collections.abc import Callable |
14 | 14 | from functools import partial |
15 | 15 | from importlib import import_module |
| 16 | +from typing import Any |
16 | 17 |
|
17 | 18 | import numpy as np |
18 | 19 | import pandas as pd |
|
27 | 28 | from xarray.compat import dask_array_compat, dask_array_ops |
28 | 29 | from xarray.compat.array_api_compat import get_array_namespace |
29 | 30 | from xarray.core import dtypes, nputils |
| 31 | +from xarray.core.extension_array import PandasExtensionArray |
30 | 32 | from xarray.core.options import OPTIONS |
31 | 33 | from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available |
32 | 34 | from xarray.namedarray.parallelcompat import get_chunked_array_type |
@@ -143,6 +145,21 @@ def round(array): |
143 | 145 | around: Callable = round |
144 | 146 |
|
145 | 147 |
|
def isna(data: Any) -> bool:
    """Report whether ``data`` is the ``pd.NA`` or ``np.nan`` object itself.

    The comparison is by identity (``is``), not by value, so other
    NaN-valued objects such as ``float("nan")`` are not matched.

    Parameters
    ----------
    data
        Any python object

    Returns
    -------
    bool
        True if ``data`` is ``pd.NA`` or ``np.nan``, False otherwise.
    """
    # Identity checks only: no element-wise comparison is ever triggered,
    # whatever kind of object is passed in.
    if data is pd.NA:
        return True
    return data is np.nan
| 161 | + |
| 162 | + |
146 | 163 | def isnull(data): |
147 | 164 | data = asarray(data) |
148 | 165 |
|
@@ -256,13 +273,20 @@ def as_shared_dtype(scalars_or_arrays, xp=None): |
256 | 273 | extension_array_types = [ |
257 | 274 | x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) |
258 | 275 | ] |
259 | | - if len(extension_array_types) == len(scalars_or_arrays) and all( |
| 276 | + non_nans = [x for x in scalars_or_arrays if not isna(x)] |
| 277 | + if len(extension_array_types) == len(non_nans) and all( |
260 | 278 | isinstance(x, type(extension_array_types[0])) for x in extension_array_types |
261 | 279 | ): |
262 | | - return scalars_or_arrays |
| 280 | + return [ |
| 281 | + x |
| 282 | + if not isna(x) |
| 283 | + else PandasExtensionArray( |
| 284 | + type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype) |
| 285 | + ) |
| 286 | + for x in scalars_or_arrays |
| 287 | + ] |
263 | 288 | raise ValueError( |
264 | | - "Cannot cast arrays to shared type, found" |
265 | | - f" array types {[x.dtype for x in scalars_or_arrays]}" |
| 289 | + f"Cannot cast values to shared type, found values: {scalars_or_arrays}" |
266 | 290 | ) |
267 | 291 |
|
268 | 292 | # Avoid calling array_type("cupy") repeatedly in the any check |
|
0 commit comments