|
57 | 57 | __all__ = ["argsort", "partition", "sort", "sort_complex"]
|
58 | 58 |
|
59 | 59 |
|
| 60 | +def _wrap_sort_argsort(a, _sorting_fn, axis=-1, kind=None, order=None): |
| 61 | + """Wrap a sorting call from dpctl.tensor interface.""" |
| 62 | + |
| 63 | + if order is not None: |
| 64 | + raise NotImplementedError( |
| 65 | + "order keyword argument is only supported with its default value." |
| 66 | + ) |
| 67 | + if kind is not None and kind != "stable": |
| 68 | + raise NotImplementedError( |
| 69 | + "kind keyword argument can only be None or 'stable'." |
| 70 | + ) |
| 71 | + |
| 72 | + usm_a = dpnp.get_usm_ndarray(a) |
| 73 | + if axis is None: |
| 74 | + usm_a = dpt.reshape(usm_a, -1) |
| 75 | + axis = -1 |
| 76 | + |
| 77 | + axis = normalize_axis_index(axis, ndim=usm_a.ndim) |
| 78 | + usm_res = _sorting_fn(usm_a, axis=axis) |
| 79 | + return dpnp_array._create_from_usm_ndarray(usm_res) |
| 80 | + |
| 81 | + |
60 | 82 | def argsort(a, axis=-1, kind=None, order=None):
|
61 | 83 | """
|
62 | 84 | Returns the indices that would sort an array.
|
@@ -134,24 +156,7 @@ def argsort(a, axis=-1, kind=None, order=None):
|
134 | 156 |
|
135 | 157 | """
|
136 | 158 |
|
137 |
| - if order is not None: |
138 |
| - raise NotImplementedError( |
139 |
| - "order keyword argument is only supported with its default value." |
140 |
| - ) |
141 |
| - if kind is not None and kind != "stable": |
142 |
| - raise NotImplementedError( |
143 |
| - "kind keyword argument can only be None or 'stable'." |
144 |
| - ) |
145 |
| - |
146 |
| - dpnp.check_supported_arrays_type(a) |
147 |
| - if axis is None: |
148 |
| - a = a.flatten() |
149 |
| - axis = -1 |
150 |
| - |
151 |
| - axis = normalize_axis_index(axis, ndim=a.ndim) |
152 |
| - return dpnp_array._create_from_usm_ndarray( |
153 |
| - dpt.argsort(dpnp.get_usm_ndarray(a), axis=axis) |
154 |
| - ) |
| 159 | + return _wrap_sort_argsort(a, dpt.argsort, axis=axis, kind=kind, order=order) |
155 | 160 |
|
156 | 161 |
|
157 | 162 | def partition(x1, kth, axis=-1, kind="introselect", order=None):
|
@@ -246,24 +251,7 @@ def sort(a, axis=-1, kind=None, order=None):
|
246 | 251 |
|
247 | 252 | """
|
248 | 253 |
|
249 |
| - if order is not None: |
250 |
| - raise NotImplementedError( |
251 |
| - "order keyword argument is only supported with its default value." |
252 |
| - ) |
253 |
| - if kind is not None and kind != "stable": |
254 |
| - raise NotImplementedError( |
255 |
| - "kind keyword argument can only be None or 'stable'." |
256 |
| - ) |
257 |
| - |
258 |
| - dpnp.check_supported_arrays_type(a) |
259 |
| - if axis is None: |
260 |
| - a = a.flatten() |
261 |
| - axis = -1 |
262 |
| - |
263 |
| - axis = normalize_axis_index(axis, ndim=a.ndim) |
264 |
| - return dpnp_array._create_from_usm_ndarray( |
265 |
| - dpt.sort(dpnp.get_usm_ndarray(a), axis=axis) |
266 |
| - ) |
| 254 | + return _wrap_sort_argsort(a, dpt.sort, axis=axis, kind=kind, order=order) |
267 | 255 |
|
268 | 256 |
|
269 | 257 | def sort_complex(a):
|
@@ -298,6 +286,8 @@ def sort_complex(a):
|
298 | 286 | b = dpnp.sort(a)
|
299 | 287 | if not dpnp.issubsctype(b.dtype, dpnp.complexfloating):
|
300 | 288 | if b.dtype.char in "bhBH":
|
301 |
| - return b.astype(dpnp.complex64) |
302 |
| - return b.astype(map_dtype_to_device(dpnp.complex128, b.sycl_device)) |
| 289 | + b_dt = dpnp.complex64 |
| 290 | + else: |
| 291 | + b_dt = map_dtype_to_device(dpnp.complex128, b.sycl_device) |
| 292 | + return b.astype(b_dt) |
303 | 293 | return b
|
0 commit comments