2020from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
2121
2222
23+ def _expected_largest_inds (inp , n , shift , k ):
24+ "Computed expected top_k indices for mode='largest'"
25+ assert k < n
26+ ones_start_id = shift % (2 * n )
27+
28+ alloc_dev = inp .device
29+
30+ if ones_start_id < n :
31+ expected_inds = dpt .arange (
32+ ones_start_id , ones_start_id + k , dtype = "i8" , device = alloc_dev
33+ )
34+ else :
35+ # wrap-around
36+ ones_end_id = (ones_start_id + n ) % (2 * n )
37+ if ones_end_id >= k :
38+ expected_inds = dpt .arange (k , dtype = "i8" , device = alloc_dev )
39+ else :
40+ expected_inds = dpt .concat (
41+ (
42+ dpt .arange (ones_end_id , dtype = "i8" , device = alloc_dev ),
43+ dpt .arange (
44+ ones_start_id ,
45+ ones_start_id + k - ones_end_id ,
46+ dtype = "i8" ,
47+ device = alloc_dev ,
48+ ),
49+ )
50+ )
51+
52+ return expected_inds
53+
54+
2355@pytest .mark .parametrize (
2456 "dtype" ,
2557 [
3870 "c16" ,
3971 ],
4072)
41- @pytest .mark .parametrize ("n" , [33 , 255 , 511 , 1021 , 8193 ])
42- def test_topk_1d_largest (dtype , n ):
73+ @pytest .mark .parametrize ("n" , [33 , 43 , 255 , 511 , 1021 , 8193 ])
74+ def test_top_k_1d_largest (dtype , n ):
4375 q = get_queue_or_skip ()
4476 skip_if_dtype_not_supported (dtype , q )
4577
78+ shift , k = 734 , 5
4679 o = dpt .ones (n , dtype = dtype )
4780 z = dpt .zeros (n , dtype = dtype )
48- zo = dpt .concat ((o , z ))
49- inp = dpt .roll (zo , 734 )
50- k = 5
81+ oz = dpt .concat ((o , z ))
82+ inp = dpt .roll (oz , shift )
83+
84+ expected_inds = _expected_largest_inds (oz , n , shift , k )
5185
5286 s = dpt .top_k (inp , k , mode = "largest" )
5387 assert s .values .shape == (k ,)
5488 assert s .values .dtype == inp .dtype
5589 assert s .indices .shape == (k ,)
56- assert dpt .all (s .values == dpt .ones (k , dtype = dtype ))
57- assert dpt .all (s .values == inp [s .indices ])
90+ assert dpt .all (s .indices == expected_inds )
91+ assert dpt .all (s .values == dpt .ones (k , dtype = dtype )), s .values
92+ assert dpt .all (s .values == inp [s .indices ]), s .indices
93+
94+
95+ def _expected_smallest_inds (inp , n , shift , k ):
96+ "Computed expected top_k indices for mode='smallest'"
97+ assert k < n
98+ zeros_start_id = (n + shift ) % (2 * n )
99+ zeros_end_id = (shift ) % (2 * n )
100+
101+ alloc_dev = inp .device
102+
103+ if zeros_start_id < zeros_end_id :
104+ expected_inds = dpt .arange (
105+ zeros_start_id , zeros_start_id + k , dtype = "i8" , device = alloc_dev
106+ )
107+ else :
108+ if zeros_end_id >= k :
109+ expected_inds = dpt .arange (k , dtype = "i8" , device = alloc_dev )
110+ else :
111+ expected_inds = dpt .concat (
112+ (
113+ dpt .arange (zeros_end_id , dtype = "i8" , device = alloc_dev ),
114+ dpt .arange (
115+ zeros_start_id ,
116+ zeros_start_id + k - zeros_end_id ,
117+ dtype = "i8" ,
118+ device = alloc_dev ,
119+ ),
120+ )
121+ )
122+
123+ return expected_inds
58124
59125
60126@pytest .mark .parametrize (
@@ -75,20 +141,23 @@ def test_topk_1d_largest(dtype, n):
75141 "c16" ,
76142 ],
77143)
78- @pytest .mark .parametrize ("n" , [33 , 255 , 257 , 513 , 1021 , 8193 ])
79- def test_topk_1d_smallest (dtype , n ):
144+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
145+ def test_top_k_1d_smallest (dtype , n ):
80146 q = get_queue_or_skip ()
81147 skip_if_dtype_not_supported (dtype , q )
82148
149+ shift , k = 734 , 5
83150 o = dpt .ones (n , dtype = dtype )
84151 z = dpt .zeros (n , dtype = dtype )
85- zo = dpt .concat ((o , z ))
86- inp = dpt .roll (zo , 734 )
87- k = 5
152+ oz = dpt .concat ((o , z ))
153+ inp = dpt .roll (oz , shift )
154+
155+ expected_inds = _expected_smallest_inds (oz , n , shift , k )
88156
89157 s = dpt .top_k (inp , k , mode = "smallest" )
90158 assert s .values .shape == (k ,)
91159 assert s .values .dtype == inp .dtype
92160 assert s .indices .shape == (k ,)
93- assert dpt .all (s .values == dpt .zeros (k , dtype = dtype ))
94- assert dpt .all (s .values == inp [s .indices ])
161+ assert dpt .all (s .indices == expected_inds )
162+ assert dpt .all (s .values == dpt .zeros (k , dtype = dtype )), s .values
163+ assert dpt .all (s .values == inp [s .indices ]), s .indices
0 commit comments