Skip to content

Commit 001b63f

Browse files
authored
Reduce code duplication for sorting functions (#1914)
* Remove limitations from dpnp.take implementation * Add more test to cover specail cases and increase code coverage * Applied pre-commit hook * Corrected test_over_index * Update docsctrings with resolving typos * Use dpnp.reshape() to change shape and create dpnp array from usm_ndarray result * Remove limitations from dpnp.place implementation * Update relating tests * Roll back changed in dpnp.vander * Remove data sync at the end of function * Update indexing functions * Add missing test scenario * Updated docstring in put_along_axis() and take_along_axis() and rolled back data synchronization * Remove data synchronization for dpnp.put() * Remove data synchronization for dpnp.nonzero() * Remove data synchronization for dpnp.indices() * Remove data synchronization for dpnp.extract() * Update indexing functions * Update sorting functions * Remove data sync from sort() and agrsort() * Remove data sync from dpnp.sort_complex() * Remove data sync in dpnp.get_result_array()
1 parent f19e989 commit 001b63f

File tree

1 file changed

+28
-38
lines changed

1 file changed

+28
-38
lines changed

dpnp/dpnp_iface_sorting.py

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,28 @@
5757
__all__ = ["argsort", "partition", "sort", "sort_complex"]
5858

5959

60+
def _wrap_sort_argsort(a, _sorting_fn, axis=-1, kind=None, order=None):
61+
"""Wrap a sorting call from dpctl.tensor interface."""
62+
63+
if order is not None:
64+
raise NotImplementedError(
65+
"order keyword argument is only supported with its default value."
66+
)
67+
if kind is not None and kind != "stable":
68+
raise NotImplementedError(
69+
"kind keyword argument can only be None or 'stable'."
70+
)
71+
72+
usm_a = dpnp.get_usm_ndarray(a)
73+
if axis is None:
74+
usm_a = dpt.reshape(usm_a, -1)
75+
axis = -1
76+
77+
axis = normalize_axis_index(axis, ndim=usm_a.ndim)
78+
usm_res = _sorting_fn(usm_a, axis=axis)
79+
return dpnp_array._create_from_usm_ndarray(usm_res)
80+
81+
6082
def argsort(a, axis=-1, kind=None, order=None):
6183
"""
6284
Returns the indices that would sort an array.
@@ -134,24 +156,7 @@ def argsort(a, axis=-1, kind=None, order=None):
134156
135157
"""
136158

137-
if order is not None:
138-
raise NotImplementedError(
139-
"order keyword argument is only supported with its default value."
140-
)
141-
if kind is not None and kind != "stable":
142-
raise NotImplementedError(
143-
"kind keyword argument can only be None or 'stable'."
144-
)
145-
146-
dpnp.check_supported_arrays_type(a)
147-
if axis is None:
148-
a = a.flatten()
149-
axis = -1
150-
151-
axis = normalize_axis_index(axis, ndim=a.ndim)
152-
return dpnp_array._create_from_usm_ndarray(
153-
dpt.argsort(dpnp.get_usm_ndarray(a), axis=axis)
154-
)
159+
return _wrap_sort_argsort(a, dpt.argsort, axis=axis, kind=kind, order=order)
155160

156161

157162
def partition(x1, kth, axis=-1, kind="introselect", order=None):
@@ -246,24 +251,7 @@ def sort(a, axis=-1, kind=None, order=None):
246251
247252
"""
248253

249-
if order is not None:
250-
raise NotImplementedError(
251-
"order keyword argument is only supported with its default value."
252-
)
253-
if kind is not None and kind != "stable":
254-
raise NotImplementedError(
255-
"kind keyword argument can only be None or 'stable'."
256-
)
257-
258-
dpnp.check_supported_arrays_type(a)
259-
if axis is None:
260-
a = a.flatten()
261-
axis = -1
262-
263-
axis = normalize_axis_index(axis, ndim=a.ndim)
264-
return dpnp_array._create_from_usm_ndarray(
265-
dpt.sort(dpnp.get_usm_ndarray(a), axis=axis)
266-
)
254+
return _wrap_sort_argsort(a, dpt.sort, axis=axis, kind=kind, order=order)
267255

268256

269257
def sort_complex(a):
@@ -298,6 +286,8 @@ def sort_complex(a):
298286
b = dpnp.sort(a)
299287
if not dpnp.issubsctype(b.dtype, dpnp.complexfloating):
300288
if b.dtype.char in "bhBH":
301-
return b.astype(dpnp.complex64)
302-
return b.astype(map_dtype_to_device(dpnp.complex128, b.sycl_device))
289+
b_dt = dpnp.complex64
290+
else:
291+
b_dt = map_dtype_to_device(dpnp.complex128, b.sycl_device)
292+
return b.astype(b_dt)
303293
return b

0 commit comments

Comments
 (0)