```diff
@@ -13,12 +13,11 @@
     Callable,
     Hashable,
 )
-from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, cast, overload
+from typing import TYPE_CHECKING, Any, Literal, cast, overload
 
 import numpy as np
 
 from xarray.compat.array_api_compat import to_like_array
-from xarray.computation.apply_ufunc import apply_ufunc
 from xarray.core import dtypes, duck_array_ops, utils
 from xarray.core.common import zeros_like
 from xarray.core.duck_array_ops import datetime_to_numeric
```
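The top-level `apply_ufunc` import is dropped here and re-added inside each function body in the hunks below. The diff does not state the motivation, but deferring an import to call time is the usual way to break a circular import between modules. A small runnable sketch of that pattern, using a stdlib module purely for illustration:

```python
import sys

def mean_of(values):
    # Deferred import: the module is resolved the first time this function
    # runs, not when the enclosing module is imported. If the imported
    # module in turn imported this one, the cycle would no longer bite
    # at import time.
    from statistics import mean

    return mean(values)

assert "statistics" not in sys.modules  # nothing imported yet
print(mean_of([1, 2, 3]))               # prints 2 and triggers the import
assert "statistics" in sys.modules      # now it is loaded
```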
```diff
@@ -467,6 +466,8 @@ def cross(
             " dimensions without coordinates must have have a length of 2 or 3"
         )
 
+    from xarray.computation.apply_ufunc import apply_ufunc
+
     c = apply_ufunc(
         duck_array_ops.cross,
         a,
```
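For reference, the function this hunk touches is the public `xarray.cross`; its behavior is unchanged by the diff. A minimal usage example:

```python
import xarray as xr

# Unit vectors along x and y; their cross product is the unit z vector.
a = xr.DataArray([1, 0, 0], dims="cartesian")
b = xr.DataArray([0, 1, 0], dims="cartesian")
print(xr.cross(a, b, dim="cartesian").values)  # [0 0 1]
```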
```diff
@@ -629,6 +630,8 @@ def dot(
     # subscripts should be passed to np.einsum as arg, not as kwargs. We need
     # to construct a partial function for apply_ufunc to work.
     func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs)
+    from xarray.computation.apply_ufunc import apply_ufunc
+
     result = apply_ufunc(
         func,
         *arrays,
```
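Likewise, the enclosing function here is the public `xarray.dot`, which contracts named dimensions via an einsum call (visible in the hunk). By default it sums over the dimensions the inputs share:

```python
import numpy as np
import xarray as xr

a = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y"))
b = xr.DataArray(np.arange(3.0), dims="y")
# With no dimension given, xr.dot contracts over the shared dim "y",
# leaving a result with dims ("x",).
print(xr.dot(a, b).values)  # [ 5. 14.]
```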
```diff
@@ -729,6 +732,8 @@ def where(cond, x, y, keep_attrs=None):
         keep_attrs = _get_keep_attrs(default=False)
 
     # alignment for three arguments is complicated, so don't support it yet
+    from xarray.computation.apply_ufunc import apply_ufunc
+
     result = apply_ufunc(
         duck_array_ops.where,
         cond,
```
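The third deferred import lands in the public `xarray.where`. For context, a minimal example of the three-argument form this code path serves:

```python
import xarray as xr

x = xr.DataArray([1.0, 2.0, 3.0], dims="t")
# Element-wise select: take x where the condition holds, else 0.0.
print(xr.where(x > 1.5, x, 0.0).values)  # [0. 2. 3.]
```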
```diff
@@ -951,80 +956,3 @@ def _calc_idxminmax(
     res.attrs = indx.attrs
 
     return res
-
-
-_T = TypeVar("_T", bound=Union["Dataset", "DataArray"])
-_U = TypeVar("_U", bound=Union["Dataset", "DataArray"])
-_V = TypeVar("_V", bound=Union["Dataset", "DataArray"])
-
-
-@overload
-def unify_chunks(__obj: _T) -> tuple[_T]: ...
-
-
-@overload
-def unify_chunks(__obj1: _T, __obj2: _U) -> tuple[_T, _U]: ...
-
-
-@overload
-def unify_chunks(__obj1: _T, __obj2: _U, __obj3: _V) -> tuple[_T, _U, _V]: ...
-
-
-@overload
-def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]: ...
-
-
-def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]:
-    """
-    Given any number of Dataset and/or DataArray objects, returns
-    new objects with unified chunk size along all chunked dimensions.
-
-    Returns
-    -------
-    unified (DataArray or Dataset) – Tuple of objects with the same type as
-    *objects with consistent chunk sizes for all dask-array variables
-
-    See Also
-    --------
-    dask.array.core.unify_chunks
-    """
-    from xarray.core.dataarray import DataArray
-
-    # Convert all objects to datasets
-    datasets = [
-        obj._to_temp_dataset() if isinstance(obj, DataArray) else obj.copy()
-        for obj in objects
-    ]
-
-    # Get arguments to pass into dask.array.core.unify_chunks
-    unify_chunks_args = []
-    sizes: dict[Hashable, int] = {}
-    for ds in datasets:
-        for v in ds._variables.values():
-            if v.chunks is not None:
-                # Check that sizes match across different datasets
-                for dim, size in v.sizes.items():
-                    try:
-                        if sizes[dim] != size:
-                            raise ValueError(
-                                f"Dimension {dim!r} size mismatch: {sizes[dim]} != {size}"
-                            )
-                    except KeyError:
-                        sizes[dim] = size
-                unify_chunks_args += [v._data, v._dims]
-
-    # No dask arrays: Return inputs
-    if not unify_chunks_args:
-        return objects
-
-    chunkmanager = get_chunked_array_type(*list(unify_chunks_args))
-    _, chunked_data = chunkmanager.unify_chunks(*unify_chunks_args)
-    chunked_data_iter = iter(chunked_data)
-    out: list[Dataset | DataArray] = []
-    for obj, ds in zip(objects, datasets, strict=True):
-        for k, v in ds._variables.items():
-            if v.chunks is not None:
-                ds._variables[k] = v.copy(data=next(chunked_data_iter))
-        out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds)
-
-    return tuple(out)
```
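This hunk removes `unify_chunks` from the module wholesale. The public `xarray.unify_chunks` API still exists, so the deletion reads as a relocation to another module rather than a removal of the feature, though the destination is not shown in this diff. A minimal usage example of the public function (requires dask):

```python
import numpy as np
import xarray as xr

# Two dask-backed arrays over the same dimension, chunked differently.
a = xr.DataArray(np.zeros(100), dims="x").chunk(10)
b = xr.DataArray(np.zeros(100), dims="x").chunk(25)

a2, b2 = xr.unify_chunks(a, b)
# Both outputs now share one chunking along "x" (the union of the
# original chunk boundaries), so they can be combined block-wise.
print(a2.chunks == b2.chunks)  # True
```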