11from __future__ import annotations
22
3- from ...common import _aliases
3+ from typing import Callable
4+
5+ from ...common import _aliases , array_namespace
46
57from ..._internal import get_xp
68
2931)
3032
3133from typing import TYPE_CHECKING
34+
3235if TYPE_CHECKING :
3336 from typing import Optional , Union
3437
35- from ...common ._typing import Device , Dtype , Array , NestedSequence , SupportsBufferProtocol
38+ from ...common ._typing import (
39+ Device ,
40+ Dtype ,
41+ Array ,
42+ NestedSequence ,
43+ SupportsBufferProtocol ,
44+ )
3645
3746import dask .array as da
3847
3948isdtype = get_xp (np )(_aliases .isdtype )
4049unstack = get_xp (da )(_aliases .unstack )
4150
51+
4252# da.astype doesn't respect copy=True
4353def astype (
44- x : Array ,
45- dtype : Dtype ,
46- / ,
47- * ,
48- copy : bool = True ,
49- device : Optional [Device ] = None
54+ x : Array , dtype : Dtype , / , * , copy : bool = True , device : Optional [Device ] = None
5055) -> Array :
5156 """
5257 Array API compatibility wrapper for astype().
@@ -61,8 +66,10 @@ def astype(
6166 x = x .astype (dtype )
6267 return x .copy () if copy else x
6368
69+
6470# Common aliases
6571
72+
6673# This arange func is modified from the common one to
6774# not pass stop/step as keyword arguments, which will cause
6875# an error with dask
@@ -189,6 +196,7 @@ def asarray(
189196 concatenate as concat ,
190197)
191198
199+
192200# dask.array.clip does not work unless all three arguments are provided.
193201# Furthermore, the masking workaround in common._aliases.clip cannot work with
194202# dask (meaning uint64 promoting to float64 is going to just be unfixed for
@@ -205,8 +213,10 @@ def clip(
205213 See the corresponding documentation in the array library and/or the array API
206214 specification for more details.
207215 """
216+
208217 def _isscalar (a ):
209218 return isinstance (a , (int , float , type (None )))
219+
210220 min_shape = () if _isscalar (min ) else min .shape
211221 max_shape = () if _isscalar (max ) else max .shape
212222
@@ -228,10 +238,98 @@ def _isscalar(a):
228238
229239 return astype (da .minimum (da .maximum (x , min ), max ), x .dtype )
230240
231- # exclude these from all since dask.array has no sorting functions
232- _da_unsupported = ['sort' , 'argsort' ]
233241
234- _common_aliases = [alias for alias in _aliases .__all__ if alias not in _da_unsupported ]
242+ def _ensure_single_chunk (x : Array , axis : int ) -> tuple [Array , Callable [[Array ], Array ]]:
243+ """
244+ Make sure that Array is not broken into multiple chunks along axis.
245+
246+ Returns
247+ -------
248+ x : Array
249+ The input Array with a single chunk along axis.
250+ restore : Callable[Array, Array]
251+ function to apply to the output to rechunk it back into reasonable chunks
252+ """
253+ if axis < 0 :
254+ axis += x .ndim
255+ if x .numblocks [axis ] < 2 :
256+ return x , lambda x : x
257+
258+ # Break chunks on other axes in an attempt to keep chunk size low
259+ x = x .rechunk ({i : - 1 if i == axis else "auto" for i in range (x .ndim )})
260+
261+ # Rather than reconstructing the original chunks, which can be a
262+ # very expensive affair, just break down oversized chunks without
263+ # incurring in any transfers over the network.
264+ # This has the downside of a risk of overchunking if the array is
265+ # then used in operations against other arrays that match the
266+ # original chunking pattern.
267+ return x , lambda x : x .rechunk ()
268+
269+
270+ def sort (
271+ x : Array , / , * , axis : int = - 1 , descending : bool = False , stable : bool = True
272+ ) -> Array :
273+ """
274+ Array API compatibility layer around the lack of sort() in Dask.
275+
276+ Warnings
277+ --------
278+ This function temporarily rechunks the array along `axis` to a single chunk.
279+ This can be extremely inefficient and can lead to out-of-memory errors.
280+
281+ See the corresponding documentation in the array library and/or the array API
282+ specification for more details.
283+ """
284+ x , restore = _ensure_single_chunk (x , axis )
285+
286+ meta_xp = array_namespace (x ._meta )
287+ x = da .map_blocks (
288+ meta_xp .sort ,
289+ x ,
290+ axis = axis ,
291+ meta = x ._meta ,
292+ dtype = x .dtype ,
293+ descending = descending ,
294+ stable = stable ,
295+ )
296+
297+ return restore (x )
298+
299+
300+ def argsort (
301+ x : Array , / , * , axis : int = - 1 , descending : bool = False , stable : bool = True
302+ ) -> Array :
303+ """
304+ Array API compatibility layer around the lack of argsort() in Dask.
305+
306+ See the corresponding documentation in the array library and/or the array API
307+ specification for more details.
308+
309+ Warnings
310+ --------
311+ This function temporarily rechunks the array along `axis` into a single chunk.
312+ This can be extremely inefficient and can lead to out-of-memory errors.
313+ """
314+ x , restore = _ensure_single_chunk (x , axis )
315+
316+ meta_xp = array_namespace (x ._meta )
317+ dtype = meta_xp .argsort (x ._meta ).dtype
318+ meta = meta_xp .astype (x ._meta , dtype )
319+ x = da .map_blocks (
320+ meta_xp .argsort ,
321+ x ,
322+ axis = axis ,
323+ meta = meta ,
324+ dtype = dtype ,
325+ descending = descending ,
326+ stable = stable ,
327+ )
328+
329+ return restore (x )
330+
331+
332+ _common_aliases = _aliases .__all__
235333
236334__all__ = _common_aliases + ['__array_namespace_info__' , 'asarray' , 'acos' ,
237335 'acosh' , 'asin' , 'asinh' , 'atan' , 'atan2' ,
@@ -242,4 +340,4 @@ def _isscalar(a):
242340 'complex64' , 'complex128' , 'iinfo' , 'finfo' ,
243341 'can_cast' , 'result_type' ]
244342
245- _all_ignore = ["get_xp" , "da" , "np" ]
343+ _all_ignore = ["Callable" , "array_namespace" , " get_xp" , "da" , "np" ]
0 commit comments