Skip to content

Commit 7dcaeff

Browse files
committed
fix argsort dtype
1 parent 1a7316f commit 7dcaeff

File tree

1 file changed

+68
-45
lines changed

1 file changed

+68
-45
lines changed

array_api_compat/dask/array/_aliases.py

Lines changed: 68 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Literal
3+
from typing import Callable
44

55
from ...common import _aliases, array_namespace
66

@@ -31,24 +31,27 @@
3131
)
3232

3333
from typing import TYPE_CHECKING
34+
3435
if TYPE_CHECKING:
3536
from typing import Optional, Union
3637

37-
from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol
38+
from ...common._typing import (
39+
Device,
40+
Dtype,
41+
Array,
42+
NestedSequence,
43+
SupportsBufferProtocol,
44+
)
3845

3946
import dask.array as da
4047

4148
isdtype = get_xp(np)(_aliases.isdtype)
4249
unstack = get_xp(da)(_aliases.unstack)
4350

51+
4452
# da.astype doesn't respect copy=True
4553
def astype(
46-
x: Array,
47-
dtype: Dtype,
48-
/,
49-
*,
50-
copy: bool = True,
51-
device: Optional[Device] = None
54+
x: Array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None
5255
) -> Array:
5356
"""
5457
Array API compatibility wrapper for astype().
@@ -63,8 +66,10 @@ def astype(
6366
x = x.astype(dtype)
6467
return x.copy() if copy else x
6568

69+
6670
# Common aliases
6771

72+
6873
# This arange func is modified from the common one to
6974
# not pass stop/step as keyword arguments, which will cause
7075
# an error with dask
@@ -191,6 +196,7 @@ def asarray(
191196
concatenate as concat,
192197
)
193198

199+
194200
# dask.array.clip does not work unless all three arguments are provided.
195201
# Furthermore, the masking workaround in common._aliases.clip cannot work with
196202
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
@@ -207,8 +213,10 @@ def clip(
207213
See the corresponding documentation in the array library and/or the array API
208214
specification for more details.
209215
"""
216+
210217
def _isscalar(a):
211218
return isinstance(a, (int, float, type(None)))
219+
212220
min_shape = () if _isscalar(min) else min.shape
213221
max_shape = () if _isscalar(max) else max.shape
214222

@@ -231,6 +239,35 @@ def _isscalar(a):
231239
return astype(da.minimum(da.maximum(x, min), max), x.dtype)
232240

233241

242+
def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
243+
"""
244+
Make sure that Array is not broken into multiple chunks along axis.
245+
246+
Returns
247+
-------
248+
x : Array
249+
The input Array with a single chunk along axis.
250+
restore : Callable[Array, Array]
251+
function to apply to the output to rechunk it back into reasonable chunks
252+
"""
253+
if axis < 0:
254+
axis += x.ndim
255+
if x.numblocks[axis] < 2:
256+
return x, lambda x: x
257+
258+
# Break chunks on other axes in an attempt to keep chunk size low
259+
x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
260+
261+
# Rather than reconstructing the original chunks, which can be a
262+
# very expensive affair, just break down oversized chunks without
263+
# incurring in any transfers over the network.
264+
# This has the downside of a risk of overchunking if the array is
265+
# then used in operations against other arrays that match the
266+
# original chunking pattern.
267+
restore = lambda x: x.rechunk()
268+
return x, restore
269+
270+
234271
def sort(
235272
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
236273
) -> Array:
@@ -245,7 +282,20 @@ def sort(
245282
See the corresponding documentation in the array library and/or the array API
246283
specification for more details.
247284
"""
248-
return _sort_argsort("sort", x, axis=axis, descending=descending, stable=stable)
285+
x, restore = _ensure_single_chunk(x, axis)
286+
287+
meta_xp = array_namespace(x._meta)
288+
x = da.map_blocks(
289+
meta_xp.sort,
290+
x,
291+
axis=axis,
292+
meta=x._meta,
293+
dtype=x.dtype,
294+
descending=descending,
295+
stable=stable,
296+
)
297+
298+
return restore(x)
249299

250300

251301
def argsort(
@@ -262,49 +312,22 @@ def argsort(
262312
This function temporarily rechunks the array along `axis` into a single chunk.
263313
This can be extremely inefficient and can lead to out-of-memory errors.
264314
"""
265-
return _sort_argsort("argsort", x, axis=axis, descending=descending, stable=stable)
266-
267-
268-
def _sort_argsort(
269-
func: Literal["sort", "argsort"],
270-
x: Array,
271-
/,
272-
*,
273-
axis: int,
274-
descending: bool,
275-
stable: bool,
276-
) -> Array:
277-
"""
278-
Implementation of sort() and argsort()
315+
x, restore = _ensure_single_chunk(x, axis)
279316

280-
TODO Implement sort and argsort properly in Dask on top of the shuffle subsystem.
281-
"""
282-
if axis < 0:
283-
axis += x.ndim
284-
rechunk = False
285-
if x.numblocks[axis] > 1:
286-
rechunk = True
287-
# Break chunks on other axes in an attempt to keep chunk size low
288-
x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
289317
meta_xp = array_namespace(x._meta)
318+
dtype = meta_xp.argsort(x._meta).dtype
319+
meta = meta_xp.astype(x._meta, dtype)
290320
x = da.map_blocks(
291-
getattr(meta_xp, func),
321+
meta_xp.argsort,
292322
x,
293323
axis=axis,
324+
meta=meta,
325+
dtype=dtype,
294326
descending=descending,
295327
stable=stable,
296-
dtype=x.dtype,
297-
meta=x._meta,
298328
)
299-
if rechunk:
300-
# rather than reconstructing the original chunks, which can be a
301-
# very expensive affair, just break down oversized chunks without
302-
# incurring in any transfers over the network.
303-
# This has the downside of a risk of overchunking if the array is
304-
# then used in operations against other arrays that match the
305-
# original chunking pattern.
306-
x = x.rechunk()
307-
return x
329+
330+
return restore(x)
308331

309332

310333
_common_aliases = _aliases.__all__
@@ -318,4 +341,4 @@ def _sort_argsort(
318341
'complex64', 'complex128', 'iinfo', 'finfo',
319342
'can_cast', 'result_type']
320343

321-
_all_ignore = ["Literal", "array_namespace", "get_xp", "da", "np"]
344+
_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]

0 commit comments

Comments
 (0)