20
20
from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
21
21
22
22
23
+ @pytest .fixture
24
+ def skip_known_failues_on_cpu (request ):
25
+ return request .config .getoption ("--skip-known-top-k-failures-on-cpu" )
26
+
27
+
23
28
def _expected_largest_inds (inp , n , shift , k ):
24
29
"Computed expected top_k indices for mode='largest'"
25
30
assert k < n
@@ -52,10 +57,17 @@ def _expected_largest_inds(inp, n, shift, k):
52
57
return expected_inds
53
58
54
59
60
+ def _skip_if_workaround_is_needed (q , dtype , n , enabled ):
61
+ if enabled :
62
+ dev = q .sycl_device
63
+ if dev .is_cpu and dtype in ["i1" , "i2" ] and n > 128 :
64
+ pytest .skip (reason = "CPU driver bug" )
65
+
66
+
55
67
@pytest .mark .parametrize (
56
68
"dtype" ,
57
69
[
58
- pytest . param ( "i1" , marks = pytest . mark . skip ( reason = "CPU bug" )) ,
70
+ "i1" ,
59
71
"u1" ,
60
72
"i2" ,
61
73
"u2" ,
@@ -71,11 +83,10 @@ def _expected_largest_inds(inp, n, shift, k):
71
83
],
72
84
)
73
85
@pytest .mark .parametrize ("n" , [33 , 43 , 255 , 511 , 1021 , 8193 ])
74
- def test_top_k_1d_largest (dtype , n ):
86
+ def test_top_k_1d_largest (dtype , n , skip_known_failues_on_cpu ):
75
87
q = get_queue_or_skip ()
76
88
skip_if_dtype_not_supported (dtype , q )
77
- if dtype == "i1" :
78
- pytest .skip ()
89
+ _skip_if_workaround_is_needed (q , dtype , n , skip_known_failues_on_cpu )
79
90
80
91
shift , k = 734 , 5
81
92
o = dpt .ones (n , dtype = dtype )
@@ -128,7 +139,7 @@ def _expected_smallest_inds(inp, n, shift, k):
128
139
@pytest .mark .parametrize (
129
140
"dtype" ,
130
141
[
131
- pytest . param ( "i1" , marks = pytest . mark . skip ( reason = "CPU bug" )) ,
142
+ "i1" ,
132
143
"u1" ,
133
144
"i2" ,
134
145
"u2" ,
@@ -144,10 +155,12 @@ def _expected_smallest_inds(inp, n, shift, k):
144
155
],
145
156
)
146
157
@pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
147
- def test_top_k_1d_smallest (dtype , n ):
158
+ def test_top_k_1d_smallest (dtype , n , skip_known_failues_on_cpu ):
148
159
q = get_queue_or_skip ()
149
160
skip_if_dtype_not_supported (dtype , q )
150
161
162
+ _skip_if_workaround_is_needed (q , dtype , n , skip_known_failues_on_cpu )
163
+
151
164
shift , k = 734 , 5
152
165
o = dpt .ones (n , dtype = dtype )
153
166
z = dpt .zeros (n , dtype = dtype )
@@ -163,3 +176,91 @@ def test_top_k_1d_smallest(dtype, n):
163
176
assert dpt .all (s .indices == expected_inds )
164
177
assert dpt .all (s .values == dpt .zeros (k , dtype = dtype )), s .values
165
178
assert dpt .all (s .values == inp [s .indices ]), s .indices
179
+
180
+
181
+ @pytest .mark .parametrize (
182
+ "dtype" ,
183
+ [
184
+ # skip short types to ensure that m*n can be represented
185
+ # in the type
186
+ "i4" ,
187
+ "u4" ,
188
+ "i8" ,
189
+ "u8" ,
190
+ "f2" ,
191
+ "f4" ,
192
+ "f8" ,
193
+ "c8" ,
194
+ "c16" ,
195
+ ],
196
+ )
197
+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
198
+ def test_top_k_2d_largest (dtype , n ):
199
+ q = get_queue_or_skip ()
200
+ skip_if_dtype_not_supported (dtype , q )
201
+
202
+ m , k = 8 , 3
203
+ if dtype == "f2" and m * n > 2000 :
204
+ pytest .skip (
205
+ "f2 can not distinguish between large integers used in this test"
206
+ )
207
+
208
+ x = dpt .reshape (dpt .arange (m * n , dtype = dtype ), (m , n ))
209
+
210
+ r = dpt .top_k (x , k , axis = 1 )
211
+
212
+ assert r .values .shape == (m , k )
213
+ assert r .indices .shape == (m , k )
214
+ expected_inds = dpt .reshape (dpt .arange (n , dtype = r .indices .dtype ), (1 , n ))[
215
+ :, - k :
216
+ ]
217
+ assert expected_inds .shape == (1 , k )
218
+ assert dpt .all (
219
+ dpt .sort (r .indices , axis = 1 ) == dpt .sort (expected_inds , axis = 1 )
220
+ ), (r .indices , expected_inds )
221
+ expected_vals = x [:, - k :]
222
+ assert dpt .all (
223
+ dpt .sort (r .values , axis = 1 ) == dpt .sort (expected_vals , axis = 1 )
224
+ )
225
+
226
+
227
+ @pytest .mark .parametrize (
228
+ "dtype" ,
229
+ [
230
+ # skip short types to ensure that m*n can be represented
231
+ # in the type
232
+ "i4" ,
233
+ "u4" ,
234
+ "i8" ,
235
+ "u8" ,
236
+ "f2" ,
237
+ "f4" ,
238
+ "f8" ,
239
+ "c8" ,
240
+ "c16" ,
241
+ ],
242
+ )
243
+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
244
+ def test_top_k_2d_smallest (dtype , n ):
245
+ q = get_queue_or_skip ()
246
+ skip_if_dtype_not_supported (dtype , q )
247
+
248
+ m , k = 8 , 3
249
+ if dtype == "f2" and m * n > 2000 :
250
+ pytest .skip (
251
+ "f2 can not distinguish between large integers used in this test"
252
+ )
253
+
254
+ x = dpt .reshape (dpt .arange (m * n , dtype = dtype ), (m , n ))
255
+
256
+ r = dpt .top_k (x , k , axis = 1 , mode = "smallest" )
257
+
258
+ assert r .values .shape == (m , k )
259
+ assert r .indices .shape == (m , k )
260
+ expected_inds = dpt .reshape (dpt .arange (n , dtype = r .indices .dtype ), (1 , n ))[
261
+ :, :k
262
+ ]
263
+ assert dpt .all (
264
+ dpt .sort (r .indices , axis = 1 ) == dpt .sort (expected_inds , axis = 1 )
265
+ )
266
+ assert dpt .all (dpt .sort (r .values , axis = 1 ) == dpt .sort (x [:, :k ], axis = 1 ))
0 commit comments