-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Avoid coercing to numpy in as_shared_dtypes
#8714
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
c6f4e3a
1467c4c
d9931ef
c067f7d
5092aaa
a884ba8
86e6bf8
630629c
45808d8
e625c67
6833d66
bcd02bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,7 +36,7 @@ | |
from xarray.core import dask_array_ops, dtypes, nputils, pycompat | ||
from xarray.core.options import OPTIONS | ||
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array | ||
from xarray.core.pycompat import array_type, is_duck_dask_array | ||
from xarray.core.pycompat import is_duck_dask_array, to_duck_array | ||
from xarray.core.utils import is_duck_array, module_available | ||
|
||
# remove once numpy 2.0 is the oldest supported version | ||
|
@@ -219,22 +219,10 @@ def asarray(data, xp=np): | |
|
||
|
||
def as_shared_dtype(scalars_or_arrays, xp=np): | ||
"""Cast a arrays to a shared dtype using xarray's type promotion rules.""" | ||
array_type_cupy = array_type("cupy") | ||
if array_type_cupy and any( | ||
isinstance(x, array_type_cupy) for x in scalars_or_arrays | ||
): | ||
import cupy as cp | ||
|
||
arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] | ||
else: | ||
arrays = [asarray(x, xp=xp) for x in scalars_or_arrays] | ||
# Pass arrays directly instead of dtypes to result_type so scalars | ||
# get handled properly. | ||
# Note that result_type() safely gets the dtype from dask arrays without | ||
# evaluating them. | ||
out_type = dtypes.result_type(*arrays) | ||
return [astype(x, out_type, copy=False) for x in arrays] | ||
"""Cast arrays to a shared dtype using xarray's type promotion rules.""" | ||
duckarrays = [to_duck_array(obj, xp=xp) for obj in scalars_or_arrays] | ||
|
||
out_type = dtypes.result_type(*duckarrays) | ||
return [astype(x, out_type, copy=False) for x in duckarrays] | ||
|
||
|
||
def broadcast_to(array, shape): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
import numpy as np | ||
from packaging.version import Version | ||
|
||
from xarray.core.types import T_DuckArray | ||
from xarray.core.utils import is_duck_array, is_scalar, module_available | ||
|
||
integer_types = (int, np.integer) | ||
|
@@ -86,23 +87,23 @@ def mod_version(mod: ModType) -> Version: | |
return _get_cached_duck_array_module(mod).version | ||
|
||
|
||
def is_dask_collection(x): | ||
def is_dask_collection(x) -> bool: | ||
if module_available("dask"): | ||
from dask.base import is_dask_collection | ||
|
||
return is_dask_collection(x) | ||
return False | ||
|
||
|
||
def is_duck_dask_array(x): | ||
def is_duck_dask_array(x) -> bool: | ||
return is_duck_array(x) and is_dask_collection(x) | ||
|
||
|
||
def is_chunked_array(x) -> bool: | ||
return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks")) | ||
|
||
|
||
def is_0d_dask_array(x): | ||
def is_0d_dask_array(x) -> bool: | ||
return is_duck_dask_array(x) and is_scalar(x) | ||
|
||
|
||
|
@@ -129,12 +130,20 @@ def to_numpy(data) -> np.ndarray: | |
return data | ||
|
||
|
||
def to_duck_array(data): | ||
def to_duck_array(data, xp=np) -> T_DuckArray: | ||
from xarray.core.indexing import ExplicitlyIndexed | ||
|
||
if isinstance(data, ExplicitlyIndexed): | ||
return data.get_duck_array() | ||
elif is_duck_array(data): | ||
return data | ||
else: | ||
return np.asarray(data) | ||
from xarray.core.duck_array_ops import asarray | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use the |
||
|
||
array_type_cupy = array_type("cupy") | ||
if array_type_cupy and any(isinstance(data, array_type_cupy)): | ||
import cupy as cp | ||
|
||
return asarray(data, xp=cp) | ||
else: | ||
return asarray(data, xp=xp) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previously this
asarray
call would coerce to numpy unnecessarily, when all we really wanted was an array type that we could examine the.dtype
attribute of.