Skip to content

Commit 65f14be

Browse files
oleksandr-pavlykndgrigorian
authored andcommitted
Add test file for top_k functionality
1 parent 70c2ef0 commit 65f14be

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

dpctl/tests/test_usm_ndarray_top_k.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2024 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl.tensor as dpt
20+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
21+
22+
23+
@pytest.mark.parametrize(
24+
"dtype",
25+
[
26+
"i1",
27+
"u1",
28+
"i2",
29+
"u2",
30+
"i4",
31+
"u4",
32+
"i8",
33+
"u8",
34+
"f2",
35+
"f4",
36+
"f8",
37+
"c8",
38+
"c16",
39+
],
40+
)
41+
@pytest.mark.parametrize("n", [33, 255, 511, 1021, 8193])
42+
def test_topk_1d_largest(dtype, n):
43+
q = get_queue_or_skip()
44+
skip_if_dtype_not_supported(dtype, q)
45+
46+
o = dpt.ones(n, dtype=dtype)
47+
z = dpt.zeros(n, dtype=dtype)
48+
zo = dpt.concat((o, z))
49+
inp = dpt.roll(zo, 734)
50+
k = 5
51+
52+
s = dpt.top_k(inp, k, mode="largest")
53+
assert s.values.shape == (k,)
54+
assert s.values.dtype == inp.dtype
55+
assert s.indices.shape == (k,)
56+
assert dpt.all(s.values == dpt.ones(k, dtype=dtype))
57+
assert dpt.all(s.values == inp[s.indices])
58+
59+
60+
@pytest.mark.parametrize(
61+
"dtype",
62+
[
63+
"i1",
64+
"u1",
65+
"i2",
66+
"u2",
67+
"i4",
68+
"u4",
69+
"i8",
70+
"u8",
71+
"f2",
72+
"f4",
73+
"f8",
74+
"c8",
75+
"c16",
76+
],
77+
)
78+
@pytest.mark.parametrize("n", [33, 255, 257, 513, 1021, 8193])
79+
def test_topk_1d_smallest(dtype, n):
80+
q = get_queue_or_skip()
81+
skip_if_dtype_not_supported(dtype, q)
82+
83+
o = dpt.ones(n, dtype=dtype)
84+
z = dpt.zeros(n, dtype=dtype)
85+
zo = dpt.concat((o, z))
86+
inp = dpt.roll(zo, 734)
87+
k = 5
88+
89+
s = dpt.top_k(inp, k, mode="smallest")
90+
assert s.values.shape == (k,)
91+
assert s.values.dtype == inp.dtype
92+
assert s.indices.shape == (k,)
93+
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype))
94+
assert dpt.all(s.values == inp[s.indices])

0 commit comments

Comments
 (0)