@@ -100,9 +100,9 @@ def test_top_k_1d_largest(dtype, n, skip_known_failues_on_cpu):
100100 assert s .values .shape == (k ,)
101101 assert s .values .dtype == inp .dtype
102102 assert s .indices .shape == (k ,)
103- assert dpt .all (s .indices == expected_inds )
104103 assert dpt .all (s .values == dpt .ones (k , dtype = dtype )), s .values
105104 assert dpt .all (s .values == inp [s .indices ]), s .indices
105+ assert dpt .all (s .indices == expected_inds ), (s .indices , expected_inds )
106106
107107
108108def _expected_smallest_inds (inp , n , shift , k ):
@@ -173,9 +173,9 @@ def test_top_k_1d_smallest(dtype, n, skip_known_failues_on_cpu):
173173 assert s .values .shape == (k ,)
174174 assert s .values .dtype == inp .dtype
175175 assert s .indices .shape == (k ,)
176- assert dpt .all (s .indices == expected_inds )
177176 assert dpt .all (s .values == dpt .zeros (k , dtype = dtype )), s .values
178177 assert dpt .all (s .values == inp [s .indices ]), s .indices
178+ assert dpt .all (s .indices == expected_inds ), (s .indices , expected_inds )
179179
180180
181181@pytest .mark .parametrize (
0 commit comments