
Commit 0e25b13

ENH: Dask: sort and argsort
1 parent 8a79994 commit 0e25b13

File tree: 3 files changed (+187, -30 lines)


array_api_compat/dask/array/_aliases.py

Lines changed: 110 additions & 12 deletions
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
-from ...common import _aliases
+from typing import Callable
+
+from ...common import _aliases, array_namespace
 
 from ..._internal import get_xp
 
@@ -29,24 +31,27 @@
 )
 
 from typing import TYPE_CHECKING
+
 if TYPE_CHECKING:
     from typing import Optional, Union
 
-    from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol
+    from ...common._typing import (
+        Device,
+        Dtype,
+        Array,
+        NestedSequence,
+        SupportsBufferProtocol,
+    )
 
 import dask.array as da
 
 isdtype = get_xp(np)(_aliases.isdtype)
 unstack = get_xp(da)(_aliases.unstack)
 
+
 # da.astype doesn't respect copy=True
 def astype(
-    x: Array,
-    dtype: Dtype,
-    /,
-    *,
-    copy: bool = True,
-    device: Optional[Device] = None
+    x: Array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None
 ) -> Array:
     """
     Array API compatibility wrapper for astype().
@@ -61,8 +66,10 @@ def astype(
     x = x.astype(dtype)
     return x.copy() if copy else x
 
+
 # Common aliases
 
+
 # This arange func is modified from the common one to
 # not pass stop/step as keyword arguments, which will cause
 # an error with dask
@@ -189,6 +196,7 @@ def asarray(
     concatenate as concat,
 )
 
+
 # dask.array.clip does not work unless all three arguments are provided.
 # Furthermore, the masking workaround in common._aliases.clip cannot work with
 # dask (meaning uint64 promoting to float64 is going to just be unfixed for
@@ -205,8 +213,10 @@ def clip(
     See the corresponding documentation in the array library and/or the array API
     specification for more details.
     """
+
     def _isscalar(a):
         return isinstance(a, (int, float, type(None)))
+
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape
 
@@ -228,10 +238,98 @@ def _isscalar(a):
 
     return astype(da.minimum(da.maximum(x, min), max), x.dtype)
 
-# exclude these from all since dask.array has no sorting functions
-_da_unsupported = ['sort', 'argsort']
 
-_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
+
+def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
+    """
+    Make sure that Array is not broken into multiple chunks along axis.
+
+    Returns
+    -------
+    x : Array
+        The input Array with a single chunk along axis.
+    restore : Callable[Array, Array]
+        function to apply to the output to rechunk it back into reasonable chunks
+    """
+    if axis < 0:
+        axis += x.ndim
+    if x.numblocks[axis] < 2:
+        return x, lambda x: x
+
+    # Break chunks on other axes in an attempt to keep chunk size low
+    x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
+
+    # Rather than reconstructing the original chunks, which can be a
+    # very expensive affair, just break down oversized chunks without
+    # incurring in any transfers over the network.
+    # This has the downside of a risk of overchunking if the array is
+    # then used in operations against other arrays that match the
+    # original chunking pattern.
+    return x, lambda x: x.rechunk()
+
+
+def sort(
+    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+) -> Array:
+    """
+    Array API compatibility layer around the lack of sort() in Dask.
+
+    Warnings
+    --------
+    This function temporarily rechunks the array along `axis` to a single chunk.
+    This can be extremely inefficient and can lead to out-of-memory errors.
+
+    See the corresponding documentation in the array library and/or the array API
+    specification for more details.
+    """
+    x, restore = _ensure_single_chunk(x, axis)
+
+    meta_xp = array_namespace(x._meta)
+    x = da.map_blocks(
+        meta_xp.sort,
+        x,
+        axis=axis,
+        meta=x._meta,
+        dtype=x.dtype,
+        descending=descending,
+        stable=stable,
+    )
+
+    return restore(x)
+
+
+def argsort(
+    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+) -> Array:
+    """
+    Array API compatibility layer around the lack of argsort() in Dask.
+
+    See the corresponding documentation in the array library and/or the array API
+    specification for more details.
+
+    Warnings
+    --------
+    This function temporarily rechunks the array along `axis` into a single chunk.
+    This can be extremely inefficient and can lead to out-of-memory errors.
+    """
+    x, restore = _ensure_single_chunk(x, axis)
+
+    meta_xp = array_namespace(x._meta)
+    dtype = meta_xp.argsort(x._meta).dtype
+    meta = meta_xp.astype(x._meta, dtype)
+    x = da.map_blocks(
+        meta_xp.argsort,
+        x,
+        axis=axis,
+        meta=meta,
+        dtype=dtype,
+        descending=descending,
+        stable=stable,
+    )
+
+    return restore(x)
+
+
+_common_aliases = _aliases.__all__
 
 __all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos',
                              'acosh', 'asin', 'asinh', 'atan', 'atan2',
@@ -242,4 +340,4 @@ def _isscalar(a):
                              'complex64', 'complex128', 'iinfo', 'finfo',
                              'can_cast', 'result_type']
 
-_all_ignore = ["get_xp", "da", "np"]
+_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]

dask-xfails.txt

Lines changed: 5 additions & 17 deletions
@@ -23,17 +23,13 @@ array_api_tests/test_array_object.py::test_setitem_masking
 # Various indexing errors
 array_api_tests/test_array_object.py::test_getitem_masking
 
-# asarray(copy=False) is not yet implemented
-# copied from numpy xfails, TODO: should this pass with dask?
-array_api_tests/test_creation_functions.py::test_asarray_arrays
-
 # zero division error, and typeerror: tuple indices must be integers or slices not tuple
 array_api_tests/test_creation_functions.py::test_eye
 
 # finfo(float32).eps returns float32 but should return float
 array_api_tests/test_data_type_functions.py::test_finfo[float32]
 
-# out[-1]=dask.aray<getitem ...> but should be some floating number
+# out[-1]=dask.array<getitem ...> but should be some floating number
 # (I think the test is not forcing the op to be computed?)
 array_api_tests/test_creation_functions.py::test_linspace
 
@@ -48,15 +44,7 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0
 array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
 array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
 
-# No sorting in dask
-array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
-array_api_tests/test_has_names.py::test_has_names[sorting-sort]
-array_api_tests/test_sorting_functions.py::test_argsort
-array_api_tests/test_sorting_functions.py::test_sort
-array_api_tests/test_signatures.py::test_func_signature[argsort]
-array_api_tests/test_signatures.py::test_func_signature[sort]
-
-# Array methods and attributes not already on np.ndarray cannot be wrapped
+# Array methods and attributes not already on da.Array cannot be wrapped
 array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
 array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
 array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
@@ -76,6 +64,7 @@ array_api_tests/test_set_functions.py::test_unique_values
 # fails for ndim > 2
 array_api_tests/test_linalg.py::test_svdvals
 array_api_tests/test_linalg.py::test_cholesky
+
 # dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :(
 array_api_tests/test_linalg.py::test_tensordot
 
@@ -105,6 +94,8 @@ array_api_tests/test_linalg.py::test_cross
 array_api_tests/test_linalg.py::test_det
 array_api_tests/test_linalg.py::test_eigh
 array_api_tests/test_linalg.py::test_eigvalsh
+array_api_tests/test_linalg.py::test_matrix_norm
+array_api_tests/test_linalg.py::test_matrix_rank
 array_api_tests/test_linalg.py::test_pinv
 array_api_tests/test_linalg.py::test_slogdet
 array_api_tests/test_has_names.py::test_has_names[linalg-cross]
@@ -115,9 +106,6 @@ array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
 array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
 array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]
 
-array_api_tests/test_linalg.py::test_matrix_norm
-array_api_tests/test_linalg.py::test_matrix_rank
-
 # missing mode kw
 # https://github.com/dask/dask/issues/10388
 array_api_tests/test_linalg.py::test_qr

tests/test_dask.py

Lines changed: 72 additions & 1 deletion
@@ -1,5 +1,6 @@
 from contextlib import contextmanager
 
+import array_api_strict
 import dask
 import numpy as np
 import pytest
@@ -20,9 +21,10 @@ def assert_no_compute():
     Context manager that raises if at any point inside it anything calls compute()
     or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc.
     """
+
     def get(dsk, *args, **kwargs):
         raise AssertionError("Called compute() or persist()")
-
+
     with dask.config.set(scheduler=get):
         yield
 
@@ -40,6 +42,7 @@ def test_assert_no_compute():
 
 # Test no_compute for functions that use generic _aliases with xp=np
 
+
 def test_unary_ops_no_compute(xp):
     with assert_no_compute():
         a = xp.asarray([1.5, -1.5])
@@ -59,6 +62,7 @@ def test_matmul_tensordot_no_compute(xp):
 
 # Test no_compute for functions that are fully bespoke for dask
 
+
 def test_asarray_no_compute(xp):
     with assert_no_compute():
         a = xp.arange(10)
@@ -88,6 +92,14 @@ def test_clip_no_compute(xp):
         xp.clip(a, 1, 8)
 
 
+@pytest.mark.parametrize("chunks", (5, 10))
+def test_sort_argsort_nocompute(xp, chunks):
+    with assert_no_compute():
+        a = xp.arange(10, chunks=chunks)
+        xp.sort(a)
+        xp.argsort(a)
+
+
 def test_generators_are_lazy(xp):
     """
     Test that generator functions are fully lazy, e.g. that
@@ -106,3 +118,62 @@ def test_generators_are_lazy(xp):
         xp.ones_like(a)
         xp.empty_like(a)
         xp.full_like(a, fill_value=123)
+
+
+@pytest.mark.parametrize("axis", [0, 1])
+@pytest.mark.parametrize("func", ["sort", "argsort"])
+def test_sort_argsort_chunks(xp, func, axis):
+    """Test that sort and argsort are functionally correct when
+    the array is chunked along the sort axis, e.g. the sort is
+    not just local to each chunk.
+    """
+    a = da.random.random((10, 10), chunks=(5, 5))
+    actual = getattr(xp, func)(a, axis=axis)
+    expect = getattr(np, func)(a.compute(), axis=axis)
+    np.testing.assert_array_equal(actual, expect)
+
+
+@pytest.mark.parametrize(
+    "shape,chunks",
+    [
+        # 3 GiB; 128 MiB per chunk; must rechunk before sorting.
+        # Sort chunks can be 128 MiB each; no need for final rechunk.
+        ((20_000, 20_000), "auto"),
+        # 3 GiB; 128 MiB per chunk; must rechunk before sorting.
+        # Must sort on two 1.5 GiB chunks; benefits from final rechunk.
+        ((2, 2**30 * 3 // 16), "auto"),
+        # 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting.
+        # Surely the user must know what they're doing, so don't
+        # perform the final rechunk.
+        ((2, 2**30 * 3 // 16), (1, -1)),
+    ],
+)
+@pytest.mark.parametrize("func", ["sort", "argsort"])
+def test_sort_argsort_chunk_size(xp, func, shape, chunks):
+    """
+    Test that sort and argsort produce reasonably-sized chunks
+    in the output array, even if they had to go through a singular
+    huge one to perform the operation.
+    """
+    a = da.random.random(shape, chunks=chunks)
+    b = getattr(xp, func)(a)
+    max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize
+    assert (
+        max_chunk_size <= 128 * 1024 * 1024  # 128 MiB
+        or b.chunks == a.chunks
+    )
+
+
+@pytest.mark.parametrize("func", ["sort", "argsort"])
+def test_sort_argsort_meta(xp, func):
+    """Test meta-namespace other than numpy"""
+    typ = type(array_api_strict.asarray(0))
+    a = da.random.random(10)
+    b = a.map_blocks(array_api_strict.asarray)
+    assert isinstance(b._meta, typ)
+    c = getattr(xp, func)(b)
+    assert isinstance(c._meta, typ)
+    d = c.compute()
+    # Note: np.sort(array_api_strict.asarray(0)) would return a numpy array
+    assert isinstance(d, typ)
+    np.testing.assert_array_equal(d, getattr(np, func)(a.compute()))
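
A short illustration of the meta/dtype handling exercised by test_sort_argsort_meta (not part of the commit; names below come from the diff or the public array_api_compat API): argsort() asks the namespace of the backing _meta array which dtype it produces, so the dask graph carries correct metadata without computing anything.

import dask.array as da
from array_api_compat import array_namespace

a = da.random.random(10)                    # numpy-backed dask array
meta_xp = array_namespace(a._meta)          # compat namespace of the meta array (numpy here)
idx_dtype = meta_xp.argsort(a._meta).dtype  # dtype inferred from the empty meta array
print(idx_dtype)                            # typically int64 (numpy's default index dtype)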
