1
1
from __future__ import annotations
2
2
3
- from typing import Literal
3
+ from typing import Callable
4
4
5
5
from ...common import _aliases , array_namespace
6
6
31
31
)
32
32
33
33
from typing import TYPE_CHECKING
34
+
34
35
if TYPE_CHECKING :
35
36
from typing import Optional , Union
36
37
37
- from ...common ._typing import Device , Dtype , Array , NestedSequence , SupportsBufferProtocol
38
+ from ...common ._typing import (
39
+ Device ,
40
+ Dtype ,
41
+ Array ,
42
+ NestedSequence ,
43
+ SupportsBufferProtocol ,
44
+ )
38
45
39
46
import dask .array as da
40
47
41
48
isdtype = get_xp (np )(_aliases .isdtype )
42
49
unstack = get_xp (da )(_aliases .unstack )
43
50
51
+
44
52
# da.astype doesn't respect copy=True
45
53
def astype (
46
- x : Array ,
47
- dtype : Dtype ,
48
- / ,
49
- * ,
50
- copy : bool = True ,
51
- device : Optional [Device ] = None
54
+ x : Array , dtype : Dtype , / , * , copy : bool = True , device : Optional [Device ] = None
52
55
) -> Array :
53
56
"""
54
57
Array API compatibility wrapper for astype().
@@ -63,8 +66,10 @@ def astype(
63
66
x = x .astype (dtype )
64
67
return x .copy () if copy else x
65
68
69
+
66
70
# Common aliases
67
71
72
+
68
73
# This arange func is modified from the common one to
69
74
# not pass stop/step as keyword arguments, which will cause
70
75
# an error with dask
@@ -191,6 +196,7 @@ def asarray(
191
196
concatenate as concat ,
192
197
)
193
198
199
+
194
200
# dask.array.clip does not work unless all three arguments are provided.
195
201
# Furthermore, the masking workaround in common._aliases.clip cannot work with
196
202
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
@@ -207,8 +213,10 @@ def clip(
207
213
See the corresponding documentation in the array library and/or the array API
208
214
specification for more details.
209
215
"""
216
+
210
217
def _isscalar (a ):
211
218
return isinstance (a , (int , float , type (None )))
219
+
212
220
min_shape = () if _isscalar (min ) else min .shape
213
221
max_shape = () if _isscalar (max ) else max .shape
214
222
@@ -231,6 +239,35 @@ def _isscalar(a):
231
239
return astype (da .minimum (da .maximum (x , min ), max ), x .dtype )
232
240
233
241
242
+ def _ensure_single_chunk (x : Array , axis : int ) -> tuple [Array , Callable [[Array ], Array ]]:
243
+ """
244
+ Make sure that Array is not broken into multiple chunks along axis.
245
+
246
+ Returns
247
+ -------
248
+ x : Array
249
+ The input Array with a single chunk along axis.
250
+ restore : Callable[Array, Array]
251
+ function to apply to the output to rechunk it back into reasonable chunks
252
+ """
253
+ if axis < 0 :
254
+ axis += x .ndim
255
+ if x .numblocks [axis ] < 2 :
256
+ return x , lambda x : x
257
+
258
+ # Break chunks on other axes in an attempt to keep chunk size low
259
+ x = x .rechunk ({i : - 1 if i == axis else "auto" for i in range (x .ndim )})
260
+
261
+ # Rather than reconstructing the original chunks, which can be a
262
+ # very expensive affair, just break down oversized chunks without
263
+ # incurring in any transfers over the network.
264
+ # This has the downside of a risk of overchunking if the array is
265
+ # then used in operations against other arrays that match the
266
+ # original chunking pattern.
267
+ restore = lambda x : x .rechunk ()
268
+ return x , restore
269
+
270
+
234
271
def sort (
235
272
x : Array , / , * , axis : int = - 1 , descending : bool = False , stable : bool = True
236
273
) -> Array :
@@ -245,7 +282,20 @@ def sort(
245
282
See the corresponding documentation in the array library and/or the array API
246
283
specification for more details.
247
284
"""
248
- return _sort_argsort ("sort" , x , axis = axis , descending = descending , stable = stable )
285
+ x , restore = _ensure_single_chunk (x , axis )
286
+
287
+ meta_xp = array_namespace (x ._meta )
288
+ x = da .map_blocks (
289
+ meta_xp .sort ,
290
+ x ,
291
+ axis = axis ,
292
+ meta = x ._meta ,
293
+ dtype = x .dtype ,
294
+ descending = descending ,
295
+ stable = stable ,
296
+ )
297
+
298
+ return restore (x )
249
299
250
300
251
301
def argsort (
@@ -262,49 +312,22 @@ def argsort(
262
312
This function temporarily rechunks the array along `axis` into a single chunk.
263
313
This can be extremely inefficient and can lead to out-of-memory errors.
264
314
"""
265
- return _sort_argsort ("argsort" , x , axis = axis , descending = descending , stable = stable )
266
-
267
-
268
- def _sort_argsort (
269
- func : Literal ["sort" , "argsort" ],
270
- x : Array ,
271
- / ,
272
- * ,
273
- axis : int ,
274
- descending : bool ,
275
- stable : bool ,
276
- ) -> Array :
277
- """
278
- Implementation of sort() and argsort()
315
+ x , restore = _ensure_single_chunk (x , axis )
279
316
280
- TODO Implement sort and argsort properly in Dask on top of the shuffle subsystem.
281
- """
282
- if axis < 0 :
283
- axis += x .ndim
284
- rechunk = False
285
- if x .numblocks [axis ] > 1 :
286
- rechunk = True
287
- # Break chunks on other axes in an attempt to keep chunk size low
288
- x = x .rechunk ({i : - 1 if i == axis else "auto" for i in range (x .ndim )})
289
317
meta_xp = array_namespace (x ._meta )
318
+ dtype = meta_xp .argsort (x ._meta ).dtype
319
+ meta = meta_xp .astype (x ._meta , dtype )
290
320
x = da .map_blocks (
291
- getattr ( meta_xp , func ) ,
321
+ meta_xp . argsort ,
292
322
x ,
293
323
axis = axis ,
324
+ meta = meta ,
325
+ dtype = dtype ,
294
326
descending = descending ,
295
327
stable = stable ,
296
- dtype = x .dtype ,
297
- meta = x ._meta ,
298
328
)
299
- if rechunk :
300
- # rather than reconstructing the original chunks, which can be a
301
- # very expensive affair, just break down oversized chunks without
302
- # incurring in any transfers over the network.
303
- # This has the downside of a risk of overchunking if the array is
304
- # then used in operations against other arrays that match the
305
- # original chunking pattern.
306
- x = x .rechunk ()
307
- return x
329
+
330
+ return restore (x )
308
331
309
332
310
333
_common_aliases = _aliases .__all__
@@ -318,4 +341,4 @@ def _sort_argsort(
318
341
'complex64' , 'complex128' , 'iinfo' , 'finfo' ,
319
342
'can_cast' , 'result_type' ]
320
343
321
- _all_ignore = ["Literal " , "array_namespace" , "get_xp" , "da" , "np" ]
344
+ _all_ignore = ["Callable " , "array_namespace" , "get_xp" , "da" , "np" ]
0 commit comments