Skip to content

Commit 9d2cf92

Browse files
author
Raghuveer Devulapalli
committed
Add more tests
1 parent 9bc2e5a commit 9d2cf92

File tree

2 files changed

+80
-7
lines changed

2 files changed

+80
-7
lines changed

tests/test_argsort.cpp

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ std::vector<int64_t> std_argsort(const std::vector<T> &array)
3030
return indices;
3131
}
3232

33-
#define EXPECT_UNIQUE(arr) \
34-
std::sort(arr.begin(), arr.end()); \
35-
std::vector<int64_t> vec(arr.size()); \
36-
std::iota(vec.begin(), vec.end(), 0); \
37-
EXPECT_EQ(arr, vec) << "Indices aren't unique. Array size = " << arr.size();
33+
#define EXPECT_UNIQUE(sorted_arg) \
34+
std::sort(sorted_arg.begin(), sorted_arg.end()); \
35+
std::vector<int64_t> expected_arg(sorted_arg.size()); \
36+
std::iota(expected_arg.begin(), expected_arg.end(), 0); \
37+
EXPECT_EQ(sorted_arg, expected_arg) << "Indices aren't unique. Array size = " << sorted_arg.size();
3838

3939
TYPED_TEST_P(avx512argsort, test_random)
4040
{
@@ -225,7 +225,7 @@ TYPED_TEST_P(avx512argsort, test_max_value_at_end_of_array)
225225
GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
226226
}
227227
std::vector<int64_t> arrsizes;
228-
for (int64_t ii = 1; ii <= 64; ++ii) {
228+
for (int64_t ii = 1; ii <= 256; ++ii) {
229229
arrsizes.push_back(ii);
230230
}
231231
std::vector<TypeParam> arr;
@@ -250,12 +250,52 @@ TYPED_TEST_P(avx512argsort, test_max_value_at_end_of_array)
250250
}
251251
}
252252

253+
TYPED_TEST_P(avx512argsort, test_all_inf_array)
254+
{
255+
if (!cpu_has_avx512bw()) {
256+
GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
257+
}
258+
std::vector<int64_t> arrsizes;
259+
for (int64_t ii = 1; ii <= 256; ++ii) {
260+
arrsizes.push_back(ii);
261+
}
262+
std::vector<TypeParam> arr;
263+
for (auto &size : arrsizes) {
264+
arr = get_uniform_rand_array<TypeParam>(size);
265+
if (std::numeric_limits<TypeParam>::has_infinity) {
266+
for (int64_t jj = 1; jj <= size; ++jj) {
267+
if (rand() % 0x1) {
268+
arr.push_back(std::numeric_limits<TypeParam>::infinity());
269+
}
270+
}
271+
}
272+
else {
273+
for (int64_t jj = 1; jj <= size; ++jj) {
274+
if (rand() % 0x1) {
275+
arr.push_back(std::numeric_limits<TypeParam>::max());
276+
}
277+
}
278+
}
279+
std::vector<int64_t> inx = avx512_argsort(arr.data(), arr.size());
280+
std::vector<TypeParam> sorted;
281+
for (size_t jj = 0; jj < size; ++jj) {
282+
sorted.push_back(arr[inx[jj]]);
283+
}
284+
if (!std::is_sorted(sorted.begin(), sorted.end())) {
285+
EXPECT_TRUE(false) << "Array of size " << size << "is not sorted";
286+
}
287+
EXPECT_UNIQUE(inx)
288+
arr.clear();
289+
}
290+
}
291+
253292
REGISTER_TYPED_TEST_SUITE_P(avx512argsort,
254293
test_random,
255294
test_reverse,
256295
test_constant,
257296
test_sorted,
258297
test_small_range,
298+
test_all_inf_array,
259299
test_array_with_nan,
260300
test_max_value_at_end_of_array);
261301

tests/test_qsort.hpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,41 @@ TYPED_TEST_P(avx512_sort, test_small_range)
128128
GTEST_SKIP() << "Skipping this test, it requires avx512bw";
129129
}
130130
}
131+
132+
TYPED_TEST_P(avx512_sort, test_max_value_at_end_of_array)
133+
{
134+
if (!cpu_has_avx512bw()) {
135+
GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
136+
}
137+
if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) {
138+
GTEST_SKIP() << "Skipping this test, it requires avx512_vbmi2";
139+
}
140+
std::vector<int64_t> arrsizes;
141+
for (int64_t ii = 1; ii <= 1024; ++ii) {
142+
arrsizes.push_back(ii);
143+
}
144+
std::vector<TypeParam> arr;
145+
std::vector<TypeParam> sortedarr;
146+
for (auto &size : arrsizes) {
147+
arr = get_uniform_rand_array<TypeParam>(size);
148+
if (std::numeric_limits<TypeParam>::has_infinity) {
149+
arr[size - 1] = std::numeric_limits<TypeParam>::infinity();
150+
}
151+
else {
152+
arr[size - 1] = std::numeric_limits<TypeParam>::max();
153+
}
154+
sortedarr = arr;
155+
avx512_qsort(arr.data(), arr.size());
156+
std::sort(sortedarr.begin(), sortedarr.end());
157+
EXPECT_EQ(sortedarr, arr) << "Array size = " << size;
158+
arr.clear();
159+
sortedarr.clear();
160+
}
161+
}
162+
131163
REGISTER_TYPED_TEST_SUITE_P(avx512_sort,
132164
test_random,
133165
test_reverse,
134166
test_constant,
135-
test_small_range);
167+
test_small_range,
168+
test_max_value_at_end_of_array);

0 commit comments

Comments
 (0)