Skip to content
Open
72 changes: 37 additions & 35 deletions xarray/computation/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@

import functools
from collections import Counter
from collections.abc import (
Callable,
Hashable,
)
from collections.abc import Callable, Hashable
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import numpy as np
Expand All @@ -23,10 +20,7 @@
from xarray.core.duck_array_ops import datetime_to_numeric
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import (
is_scalar,
parse_dims_as_set,
)
from xarray.core.utils import is_scalar, parse_dims_as_set
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
Expand Down Expand Up @@ -912,14 +906,17 @@ def _calc_idxminmax(
# The dim is not specified and ambiguous. Don't guess.
raise ValueError("Must supply 'dim' argument for multidimensional arrays")

if dim not in array.dims:
raise KeyError(
f"Dimension {dim!r} not found in array dimensions {array.dims!r}"
)
if dim not in array.coords:
raise KeyError(
f"Dimension {dim!r} is not one of the coordinates {tuple(array.coords.keys())}"
)
dims = [dim] if isinstance(dim, str) else list(dim)

for _dim in dims:
if _dim not in array.dims:
raise KeyError(
f"Dimension {_dim!r} not found in array dimensions {array.dims!r}"
)
if _dim not in array.coords:
raise KeyError(
f"Dimension {_dim!r} is not one of the coordinates {tuple(array.coords.keys())}"
)

# These are dtypes with NaN values argmin and argmax can handle
na_dtypes = "cfO"
Expand All @@ -931,25 +928,30 @@ def _calc_idxminmax(

# This will run argmin or argmax.
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)
# Force dictionary format in case of single dim so that we can iterate over it in for loop below
if len(dims) == 1:
indx = {dims[0]: indx}

res = {}
for _dim, _da_idx in zip(dims, indx.values(), strict=False):
# Handle chunked arrays (e.g. dask).
coord = array[_dim]._variable.to_base_variable()
if is_chunked_array(array.data):
chunkmanager = get_chunked_array_type(array.data)
coord_array = chunkmanager.from_array(
array[_dim].data, chunks=((array.sizes[_dim],),)
)
coord = coord.copy(data=coord_array)
else:
coord = coord.copy(data=to_like_array(array[_dim].data, array.data))

# Handle chunked arrays (e.g. dask).
coord = array[dim]._variable.to_base_variable()
if is_chunked_array(array.data):
chunkmanager = get_chunked_array_type(array.data)
coord_array = chunkmanager.from_array(
array[dim].data, chunks=((array.sizes[dim],),)
)
coord = coord.copy(data=coord_array)
else:
coord = coord.copy(data=to_like_array(array[dim].data, array.data))

res = indx._replace(coord[(indx.variable,)]).rename(dim)

if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
res = res.where(~allna, fill_value)

# Copy attributes from argmin/argmax, if any
res.attrs = indx.attrs
_res = _da_idx._replace(coord[(_da_idx.variable,)]).rename(_dim)
if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
_res = _res.where(~allna, fill_value)
_res.attrs = _da_idx.attrs
res[_dim] = _res

if len(dims) == 1:
res = res[dims[0]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should have some type stability here so

idmax(dim) -> array; idxmax((dim,)) -> tuple[array]; idxmax((dim0, dim1, ...)) -> tuple[array, ...]

Copy link
Contributor Author

@gcaria gcaria Sep 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed the code to match the behavior of DataArray.arg* which returns a dict for both idx*((dim,)) and idx*((dim0, dim1, ...))

Does that seem sensible?

Currently navigating the existing tests for arg* and multiple dims

return res
Loading