1515 is_jax_array ,
1616 is_writeable_array ,
1717)
18+ from ._utils ._helpers import meta_namespace
1819from ._utils ._typing import Array , Index
1920
2021
@@ -419,9 +420,16 @@ def min(
419420 xp : ModuleType | None = None ,
420421 ) -> Array : # numpydoc ignore=PR01,RT01
421422 """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
423+ # On Dask, this function runs on the chunks, so we need to determine the
424+ # namespace that Dask is wrapping.
425+ # Note that da.minimum _incidentally_ works on numpy, cupy, and sparse
426+ # thanks to all these meta-namespaces implementing the __array_ufunc__
427+ # interface, but there's no guarantee that it will work for other
428+ # wrapped libraries in the future.
422429 xp = array_namespace (self ._x ) if xp is None else xp
430+ mxp = meta_namespace (self ._x , xp = xp )
423431 y = xp .asarray (y )
424- return self ._op (_AtOp .MIN , xp .minimum , xp .minimum , y , copy = copy , xp = xp )
432+ return self ._op (_AtOp .MIN , mxp .minimum , mxp .minimum , y , copy = copy , xp = xp )
425433
426434 def max (
427435 self ,
@@ -431,6 +439,8 @@ def max(
431439 xp : ModuleType | None = None ,
432440 ) -> Array : # numpydoc ignore=PR01,RT01
433441 """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
442+ # See note on min()
434443 xp = array_namespace (self ._x ) if xp is None else xp
444+ mxp = meta_namespace (self ._x , xp = xp )
435445 y = xp .asarray (y )
436- return self ._op (_AtOp .MAX , xp .maximum , xp .maximum , y , copy = copy , xp = xp )
446+ return self ._op (_AtOp .MAX , mxp .maximum , mxp .maximum , y , copy = copy , xp = xp )
0 commit comments