1
1
from __future__ import annotations
2
2
3
- from ...common import _aliases
3
+ from typing import Callable
4
+
5
+ from ...common import _aliases , array_namespace
4
6
5
7
from ..._internal import get_xp
6
8
29
31
)
30
32
31
33
from typing import TYPE_CHECKING
34
+
32
35
if TYPE_CHECKING :
33
36
from typing import Optional , Union
34
37
35
- from ...common ._typing import Device , Dtype , Array , NestedSequence , SupportsBufferProtocol
38
+ from ...common ._typing import (
39
+ Device ,
40
+ Dtype ,
41
+ Array ,
42
+ NestedSequence ,
43
+ SupportsBufferProtocol ,
44
+ )
36
45
37
46
import dask .array as da
38
47
39
48
isdtype = get_xp (np )(_aliases .isdtype )
40
49
unstack = get_xp (da )(_aliases .unstack )
41
50
51
+
42
52
# da.astype doesn't respect copy=True
43
53
def astype (
44
- x : Array ,
45
- dtype : Dtype ,
46
- / ,
47
- * ,
48
- copy : bool = True ,
49
- device : Optional [Device ] = None
54
+ x : Array , dtype : Dtype , / , * , copy : bool = True , device : Optional [Device ] = None
50
55
) -> Array :
51
56
"""
52
57
Array API compatibility wrapper for astype().
@@ -61,8 +66,10 @@ def astype(
61
66
x = x .astype (dtype )
62
67
return x .copy () if copy else x
63
68
69
+
64
70
# Common aliases
65
71
72
+
66
73
# This arange func is modified from the common one to
67
74
# not pass stop/step as keyword arguments, which will cause
68
75
# an error with dask
@@ -189,6 +196,7 @@ def asarray(
189
196
concatenate as concat ,
190
197
)
191
198
199
+
192
200
# dask.array.clip does not work unless all three arguments are provided.
193
201
# Furthermore, the masking workaround in common._aliases.clip cannot work with
194
202
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
@@ -205,8 +213,10 @@ def clip(
205
213
See the corresponding documentation in the array library and/or the array API
206
214
specification for more details.
207
215
"""
216
+
208
217
def _isscalar (a ):
209
218
return isinstance (a , (int , float , type (None )))
219
+
210
220
min_shape = () if _isscalar (min ) else min .shape
211
221
max_shape = () if _isscalar (max ) else max .shape
212
222
@@ -228,10 +238,98 @@ def _isscalar(a):
228
238
229
239
return astype (da .minimum (da .maximum (x , min ), max ), x .dtype )
230
240
231
- # exclude these from all since dask.array has no sorting functions
232
- _da_unsupported = ['sort' , 'argsort' ]
233
241
234
- _common_aliases = [alias for alias in _aliases .__all__ if alias not in _da_unsupported ]
242
def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
    """
    Make sure that Array is not broken into multiple chunks along axis.

    Parameters
    ----------
    x : Array
        Array to inspect; its ``ndim``, ``numblocks`` and ``rechunk`` are used.
    axis : int
        Axis that must end up in a single chunk. Negative values count
        from the last axis.

    Returns
    -------
    x : Array
        The input Array with a single chunk along axis.
    restore : Callable[[Array], Array]
        function to apply to the output to rechunk it back into reasonable chunks

    Raises
    ------
    IndexError
        If ``axis`` is not in the range ``[-x.ndim, x.ndim)``.
    """
    ndim = x.ndim
    # Check the range before normalising: an axis below -ndim would otherwise
    # still be negative after the shift, index ``numblocks`` from the end and
    # never match any key of the rechunk mapping built below.
    if not -ndim <= axis < ndim:
        raise IndexError(
            f"axis {axis} is out of bounds for array of dimension {ndim}"
        )
    if axis < 0:
        axis += ndim
    if x.numblocks[axis] < 2:
        # Already a single chunk along axis: nothing to do and nothing to undo.
        return x, lambda arr: arr

    # Break chunks on other axes in an attempt to keep chunk size low
    x = x.rechunk({i: -1 if i == axis else "auto" for i in range(ndim)})

    # Rather than reconstructing the original chunks, which can be a
    # very expensive affair, just break down oversized chunks without
    # incurring in any transfers over the network.
    # This has the downside of a risk of overchunking if the array is
    # then used in operations against other arrays that match the
    # original chunking pattern.
    return x, lambda arr: arr.rechunk()
268
+
269
+
270
def sort(
    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
    """
    Array API compatibility layer around the lack of sort() in Dask.

    See the corresponding documentation in the array library and/or the array API
    specification for more details.

    Warnings
    --------
    This function temporarily rechunks the array along `axis` to a single chunk.
    This can be extremely inefficient and can lead to out-of-memory errors.
    """
    x, restore = _ensure_single_chunk(x, axis)

    # Delegate to the sort() of the namespace matching the chunk type
    # (taken from the array's meta), applied independently to every block.
    meta_xp = array_namespace(x._meta)
    x = da.map_blocks(
        meta_xp.sort,
        x,
        axis=axis,
        meta=x._meta,
        dtype=x.dtype,
        descending=descending,
        stable=stable,
    )

    return restore(x)
298
+
299
+
300
def argsort(
    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
    """
    Array API compatibility layer around the lack of argsort() in Dask.

    See the corresponding documentation in the array library and/or the array API
    specification for more details.

    Warnings
    --------
    This function temporarily rechunks the array along `axis` into a single chunk.
    This can be extremely inefficient and can lead to out-of-memory errors.
    """
    x, restore = _ensure_single_chunk(x, axis)

    meta_xp = array_namespace(x._meta)
    # The index dtype is whatever the chunk library's argsort produces;
    # discover it by running argsort on the meta array.
    dtype = meta_xp.argsort(x._meta).dtype
    # Output meta: same array type as the input meta, cast to the index dtype.
    meta = meta_xp.astype(x._meta, dtype)
    x = da.map_blocks(
        meta_xp.argsort,
        x,
        axis=axis,
        meta=meta,
        dtype=dtype,
        descending=descending,
        stable=stable,
    )

    return restore(x)
330
+
331
+
332
+ _common_aliases = _aliases .__all__
235
333
236
334
__all__ = _common_aliases + ['__array_namespace_info__' , 'asarray' , 'acos' ,
237
335
'acosh' , 'asin' , 'asinh' , 'atan' , 'atan2' ,
@@ -242,4 +340,4 @@ def _isscalar(a):
242
340
'complex64' , 'complex128' , 'iinfo' , 'finfo' ,
243
341
'can_cast' , 'result_type' ]
244
342
245
- _all_ignore = ["get_xp" , "da" , "np" ]
343
# Entries are compared as exact identifiers, so they must carry no padding
# (" get_xp" would never match the imported name ``get_xp``).
_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]
0 commit comments