Skip to content

Commit fa74e5e

Browse files
committed
Move top_k into _sorting.py
1 parent 1f0e9fd commit fa74e5e

File tree

3 files changed

+166
-187
lines changed

3 files changed

+166
-187
lines changed

dpctl/tensor/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,8 @@
199199
unique_inverse,
200200
unique_values,
201201
)
202-
from ._sorting import argsort, sort
202+
from ._sorting import argsort, sort, top_k
203203
from ._testing import allclose
204-
from ._topk import top_k
205204
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
206205

207206
__all__ = [

dpctl/tensor/_sorting.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import operator
18+
from typing import NamedTuple
19+
1720
import dpctl.tensor as dpt
1821
import dpctl.tensor._tensor_impl as ti
1922
import dpctl.utils as du
@@ -24,6 +27,7 @@
2427
_argsort_descending,
2528
_sort_ascending,
2629
_sort_descending,
30+
_topk,
2731
)
2832
from ._tensor_sorting_radix_impl import (
2933
_radix_argsort_ascending,
@@ -267,3 +271,164 @@ def argsort(x, axis=-1, descending=False, stable=True, kind=None):
267271
inv_perm = sorted(range(nd), key=lambda d: perm[d])
268272
res = dpt.permute_dims(res, inv_perm)
269273
return res
274+
275+
276+
def _get_top_k_largest(mode):
277+
modes = {"largest": True, "smallest": False}
278+
try:
279+
return modes[mode]
280+
except KeyError:
281+
raise ValueError(
282+
f"`mode` must be `largest` or `smallest`. Got `{mode}`."
283+
)
284+
285+
286+
class TopKResult(NamedTuple):
287+
values: dpt.usm_ndarray
288+
indices: dpt.usm_ndarray
289+
290+
291+
def top_k(x, k, /, *, axis=None, mode="largest"):
292+
"""top_k(x, k, axis=None, mode="largest")
293+
294+
Returns the `k` largest or smallest values and their indices in the input
295+
array `x` along the specified axis `axis`.
296+
297+
Args:
298+
x (usm_ndarray):
299+
input array.
300+
k (int):
301+
number of elements to find. Must be a positive integer value.
302+
axis (Optional[int]):
303+
axis along which to search. If `None`, the search will be performed
304+
over the flattened array. Default: ``None``.
305+
mode (Literal["largest", "smallest"]):
306+
search mode. Must be one of the following modes:
307+
308+
- `"largest"`: return the `k` largest elements.
309+
- `"smallest"`: return the `k` smallest elements.
310+
311+
Default: `"largest"`.
312+
313+
Returns:
314+
tuple[usm_ndarray, usm_ndarray]
315+
a namedtuple `(values, indices)` whose
316+
317+
* first element `values` will be an array containing the `k`
318+
largest or smallest elements of `x`. The array has the same data
319+
type as `x`. If `axis` was `None`, `values` will be a
320+
one-dimensional array with shape `(k,)` and otherwise, `values`
321+
will have shape `x.shape[:axis] + (k,) + x.shape[axis+1:]`
322+
* second element `indices` will be an array containing indices of
323+
`x` that result in `values`. The array will have the same shape
324+
as `values` and will have the default array index data type.
325+
"""
326+
largest = _get_top_k_largest(mode)
327+
if not isinstance(x, dpt.usm_ndarray):
328+
raise TypeError(
329+
f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
330+
)
331+
332+
k = operator.index(k)
333+
if k < 0:
334+
raise ValueError("`k` must be a positive integer value")
335+
336+
nd = x.ndim
337+
if axis is None:
338+
sz = x.size
339+
if nd == 0:
340+
return TopKResult(
341+
dpt.copy(x, order="C"),
342+
dpt.zeros_like(
343+
x, dtype=ti.default_device_index_type(x.sycl_queue)
344+
),
345+
)
346+
arr = x
347+
n_search_dims = None
348+
res_sh = k
349+
else:
350+
axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
351+
sz = x.shape[axis]
352+
a1 = axis + 1
353+
if a1 == nd:
354+
perm = list(range(nd))
355+
arr = x
356+
else:
357+
perm = [i for i in range(nd) if i != axis] + [
358+
axis,
359+
]
360+
arr = dpt.permute_dims(x, perm)
361+
n_search_dims = 1
362+
res_sh = arr.shape[: nd - 1] + (k,)
363+
364+
if k > sz:
365+
raise ValueError(f"`k`={k} is out of bounds {sz}")
366+
367+
exec_q = x.sycl_queue
368+
_manager = du.SequentialOrderManager[exec_q]
369+
dep_evs = _manager.submitted_events
370+
371+
res_usm_type = arr.usm_type
372+
if arr.flags.c_contiguous:
373+
vals = dpt.empty(
374+
res_sh,
375+
dtype=arr.dtype,
376+
usm_type=res_usm_type,
377+
order="C",
378+
sycl_queue=exec_q,
379+
)
380+
inds = dpt.empty(
381+
res_sh,
382+
dtype=ti.default_device_index_type(exec_q),
383+
usm_type=res_usm_type,
384+
order="C",
385+
sycl_queue=exec_q,
386+
)
387+
ht_ev, impl_ev = _topk(
388+
src=arr,
389+
trailing_dims_to_search=n_search_dims,
390+
k=k,
391+
largest=largest,
392+
vals=vals,
393+
inds=inds,
394+
sycl_queue=exec_q,
395+
depends=dep_evs,
396+
)
397+
_manager.add_event_pair(ht_ev, impl_ev)
398+
else:
399+
tmp = dpt.empty_like(arr, order="C")
400+
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
401+
src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs
402+
)
403+
_manager.add_event_pair(ht_ev, copy_ev)
404+
vals = dpt.empty(
405+
res_sh,
406+
dtype=arr.dtype,
407+
usm_type=res_usm_type,
408+
order="C",
409+
sycl_queue=exec_q,
410+
)
411+
inds = dpt.empty(
412+
res_sh,
413+
dtype=ti.default_device_index_type(exec_q),
414+
usm_type=res_usm_type,
415+
order="C",
416+
sycl_queue=exec_q,
417+
)
418+
ht_ev, impl_ev = _topk(
419+
src=tmp,
420+
trailing_dims_to_search=n_search_dims,
421+
k=k,
422+
largest=largest,
423+
vals=vals,
424+
inds=inds,
425+
sycl_queue=exec_q,
426+
depends=[copy_ev],
427+
)
428+
_manager.add_event_pair(ht_ev, impl_ev)
429+
if axis is not None and a1 != nd:
430+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
431+
vals = dpt.permute_dims(vals, inv_perm)
432+
inds = dpt.permute_dims(inds, inv_perm)
433+
434+
return TopKResult(vals, inds)

dpctl/tensor/_topk.py

Lines changed: 0 additions & 185 deletions
This file was deleted.

0 commit comments

Comments
 (0)