Skip to content

Commit 437fd4a

Browse files
author
Raghuveer Devulapalli
committed
Add tests for arrays with NAN
1 parent 8bd9c42 commit 437fd4a

File tree

3 files changed

+29
-247
lines changed

3 files changed

+29
-247
lines changed

src/avx512-64bit-argsort.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize)
439439
{
440440
if (arrsize > 1) {
441441
if (has_nan<zmm_vector<double>>(arr, arrsize)) {
442-
std_argselect_withnan(arr, arg, 0, arrsize);
442+
std_argselect_withnan(arr, arg, k, 0, arrsize);
443443
}
444444
else {
445445
argselect_64bit_<zmm_vector<double>>(
@@ -471,7 +471,7 @@ void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize)
471471
{
472472
if (arrsize > 1) {
473473
if (has_nan<ymm_vector<float>>(arr, arrsize)) {
474-
std_argselect_withnan(arr, arg, 0, arrsize);
474+
std_argselect_withnan(arr, arg, k, 0, arrsize);
475475
}
476476
else {
477477
argselect_64bit_<ymm_vector<float>>(

tests/test-argselect.hpp

Lines changed: 20 additions & 239 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@ T std_min_element(std::vector<T> arr, std::vector<int64_t> arg, int64_t left, in
1515
std::vector<int64_t>::iterator res =
1616
std::min_element(arg.begin() + left,
1717
arg.begin() + right,
18-
[arr](int64_t a, int64_t b) -> bool {return arr[a] < arr[b];});
18+
[arr](int64_t a, int64_t b) -> bool {
19+
if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {return arr[a] < arr[b];}
20+
else if (std::isnan(arr[a])) {return false;}
21+
else {return true;}
22+
});
1923
return arr[*res];
2024
}
2125

@@ -25,7 +29,11 @@ T std_max_element(std::vector<T> arr, std::vector<int64_t> arg, int64_t left, in
2529
std::vector<int64_t>::iterator res =
2630
std::max_element(arg.begin() + left,
2731
arg.begin() + right,
28-
[arr](int64_t a, int64_t b) -> bool {return arr[a] > arr[b];});
32+
[arr](int64_t a, int64_t b) -> bool {
33+
if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {return arr[a] > arr[b];}
34+
else if (std::isnan(arr[a])) {return true;}
35+
else {return false;}
36+
});
2937
return arr[*res];
3038
}
3139

@@ -34,20 +42,25 @@ TYPED_TEST_P(avx512argselect, test_random)
3442
if (cpu_has_avx512bw()) {
3543
const int arrsize = 1024;
3644
auto arr = get_uniform_rand_array<TypeParam>(arrsize);
45+
std::vector<int64_t> sorted_inx;
46+
if (std::is_floating_point<TypeParam>::value) {
47+
arr[0] = std::numeric_limits<TypeParam>::quiet_NaN();
48+
arr[1] = std::numeric_limits<TypeParam>::quiet_NaN();
49+
}
50+
sorted_inx = std_argsort(arr);
3751
std::vector<int64_t> kth;
38-
for (int64_t ii = 0; ii < arrsize; ++ii) {
52+
for (int64_t ii = 0; ii < arrsize-3; ++ii) {
3953
kth.push_back(ii);
4054
}
41-
std::vector<int64_t> sorted_inx = std_argsort(arr);
4255
for (auto &k : kth) {
4356
std::vector<int64_t> inx
4457
= avx512_argselect<TypeParam>(arr.data(), k, arr.size());
4558
auto true_kth = arr[sorted_inx[k]];
4659
EXPECT_EQ(true_kth, arr[inx[k]]) << "Failed at index k = " << k;
4760
if (k >= 1)
48-
EXPECT_GE(true_kth, std_max_element(arr, inx, 0, k-1));
61+
EXPECT_GE(true_kth, std_max_element(arr, inx, 0, k-1)) << "failed at k = " << k;
4962
if (k != arrsize-1)
50-
EXPECT_LE(true_kth, std_min_element(arr, inx, k+1, arrsize-1));
63+
EXPECT_LE(true_kth, std_min_element(arr, inx, k+1, arrsize-1)) << "failed at k = " << k;
5164
EXPECT_UNIQUE(inx)
5265
}
5366
}
@@ -56,236 +69,4 @@ TYPED_TEST_P(avx512argselect, test_random)
5669
}
5770
}
5871

59-
//TYPED_TEST_P(avx512argselect, test_constant)
60-
//{
61-
// if (cpu_has_avx512bw()) {
62-
// std::vector<int64_t> arrsizes;
63-
// for (int64_t ii = 0; ii <= 1024; ++ii) {
64-
// arrsizes.push_back(ii);
65-
// }
66-
// std::vector<TypeParam> arr;
67-
// for (auto &size : arrsizes) {
68-
// /* constant array */
69-
// auto elem = get_uniform_rand_array<TypeParam>(1)[0];
70-
// for (int64_t jj = 0; jj < size; ++jj) {
71-
// arr.push_back(elem);
72-
// }
73-
// std::vector<int64_t> inx1 = std_argsort(arr);
74-
// std::vector<int64_t> inx2
75-
// = avx512_argsort<TypeParam>(arr.data(), arr.size());
76-
// std::vector<TypeParam> sort1, sort2;
77-
// for (size_t jj = 0; jj < size; ++jj) {
78-
// sort1.push_back(arr[inx1[jj]]);
79-
// sort2.push_back(arr[inx2[jj]]);
80-
// }
81-
// EXPECT_EQ(sort1, sort2) << "Array size =" << size;
82-
// EXPECT_UNIQUE(inx2)
83-
// arr.clear();
84-
// }
85-
// }
86-
// else {
87-
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
88-
// }
89-
//}
90-
//
91-
//TYPED_TEST_P(avx512argselect, test_small_range)
92-
//{
93-
// if (cpu_has_avx512bw()) {
94-
// std::vector<int64_t> arrsizes;
95-
// for (int64_t ii = 0; ii <= 1024; ++ii) {
96-
// arrsizes.push_back(ii);
97-
// }
98-
// std::vector<TypeParam> arr;
99-
// for (auto &size : arrsizes) {
100-
// /* array with a smaller range of values */
101-
// arr = get_uniform_rand_array<TypeParam>(size, 20, 1);
102-
// std::vector<int64_t> inx1 = std_argsort(arr);
103-
// std::vector<int64_t> inx2
104-
// = avx512_argsort<TypeParam>(arr.data(), arr.size());
105-
// std::vector<TypeParam> sort1, sort2;
106-
// for (size_t jj = 0; jj < size; ++jj) {
107-
// sort1.push_back(arr[inx1[jj]]);
108-
// sort2.push_back(arr[inx2[jj]]);
109-
// }
110-
// EXPECT_EQ(sort1, sort2) << "Array size = " << size;
111-
// EXPECT_UNIQUE(inx2)
112-
// arr.clear();
113-
// }
114-
// }
115-
// else {
116-
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
117-
// }
118-
//}
119-
//
120-
//TYPED_TEST_P(avx512argselect, test_sorted)
121-
//{
122-
// if (cpu_has_avx512bw()) {
123-
// std::vector<int64_t> arrsizes;
124-
// for (int64_t ii = 0; ii <= 1024; ++ii) {
125-
// arrsizes.push_back(ii);
126-
// }
127-
// std::vector<TypeParam> arr;
128-
// for (auto &size : arrsizes) {
129-
// arr = get_uniform_rand_array<TypeParam>(size);
130-
// std::sort(arr.begin(), arr.end());
131-
// std::vector<int64_t> inx1 = std_argsort(arr);
132-
// std::vector<int64_t> inx2
133-
// = avx512_argsort<TypeParam>(arr.data(), arr.size());
134-
// std::vector<TypeParam> sort1, sort2;
135-
// for (size_t jj = 0; jj < size; ++jj) {
136-
// sort1.push_back(arr[inx1[jj]]);
137-
// sort2.push_back(arr[inx2[jj]]);
138-
// }
139-
// EXPECT_EQ(sort1, sort2) << "Array size =" << size;
140-
// EXPECT_UNIQUE(inx2)
141-
// arr.clear();
142-
// }
143-
// }
144-
// else {
145-
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
146-
// }
147-
//}
148-
//
149-
//TYPED_TEST_P(avx512argselect, test_reverse)
150-
//{
151-
// if (cpu_has_avx512bw()) {
152-
// std::vector<int64_t> arrsizes;
153-
// for (int64_t ii = 0; ii <= 1024; ++ii) {
154-
// arrsizes.push_back(ii);
155-
// }
156-
// std::vector<TypeParam> arr;
157-
// for (auto &size : arrsizes) {
158-
// arr = get_uniform_rand_array<TypeParam>(size);
159-
// std::sort(arr.begin(), arr.end());
160-
// std::reverse(arr.begin(), arr.end());
161-
// std::vector<int64_t> inx1 = std_argsort(arr);
162-
// std::vector<int64_t> inx2
163-
// = avx512_argsort<TypeParam>(arr.data(), arr.size());
164-
// std::vector<TypeParam> sort1, sort2;
165-
// for (size_t jj = 0; jj < size; ++jj) {
166-
// sort1.push_back(arr[inx1[jj]]);
167-
// sort2.push_back(arr[inx2[jj]]);
168-
// }
169-
// EXPECT_EQ(sort1, sort2) << "Array size =" << size;
170-
// EXPECT_UNIQUE(inx2)
171-
// arr.clear();
172-
// }
173-
// }
174-
// else {
175-
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
176-
// }
177-
//}
178-
//
179-
//TYPED_TEST_P(avx512argselect, test_array_with_nan)
180-
//{
181-
// if (!cpu_has_avx512bw()) {
182-
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
183-
// }
184-
// if (!std::is_floating_point<TypeParam>::value) {
185-
// GTEST_SKIP() << "Skipping this test, it is meant for float/double";
186-
// }
187-
// std::vector<int64_t> arrsizes;
188-
// for (int64_t ii = 2; ii <= 1024; ++ii) {
189-
// arrsizes.push_back(ii);
190-
// }
191-
// std::vector<TypeParam> arr;
192-
// for (auto &size : arrsizes) {
193-
// arr = get_uniform_rand_array<TypeParam>(size);
194-
// arr[0] = std::numeric_limits<TypeParam>::quiet_NaN();
195-
// arr[1] = std::numeric_limits<TypeParam>::quiet_NaN();
196-
// std::vector<int64_t> inx
197-
// = avx512_argsort<TypeParam>(arr.data(), arr.size());
198-
// std::vector<TypeParam> sort1;
199-
// for (size_t jj = 0; jj < size; ++jj) {
200-
// sort1.push_back(arr[inx[jj]]);
201-
// }
202-
// if ((!std::isnan(sort1[size - 1])) || (!std::isnan(sort1[size - 2]))) {
203-
// FAIL() << "NAN's aren't sorted to the end";
204-
// }
205-
// if (!std::is_sorted(sort1.begin(), sort1.end() - 2)) {
206-
// FAIL() << "Array isn't sorted";
207-
// }
208-
// EXPECT_UNIQUE(inx)
209-
// arr.clear();
210-
// }
211-
//}
212-
//
213-
//TYPED_TEST_P(avx512argselect, test_max_value_at_end_of_array)
214-
//{
215-
// if (!cpu_has_avx512bw()) {
216-
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
217-
// }
218-
// std::vector<int64_t> arrsizes;
219-
// for (int64_t ii = 1; ii <= 256; ++ii) {
220-
// arrsizes.push_back(ii);
221-
// }
222-
// std::vector<TypeParam> arr;
223-
// for (auto &size : arrsizes) {
224-
// arr = get_uniform_rand_array<TypeParam>(size);
225-
// if (std::numeric_limits<TypeParam>::has_infinity) {
226-
// arr[size - 1] = std::numeric_limits<TypeParam>::infinity();
227-
// }
228-
// else {
229-
// arr[size - 1] = std::numeric_limits<TypeParam>::max();
230-
// }
231-
// std::vector<int64_t> inx = avx512_argsort(arr.data(), arr.size());
232-
// std::vector<TypeParam> sorted;
233-
// for (size_t jj = 0; jj < size; ++jj) {
234-
// sorted.push_back(arr[inx[jj]]);
235-
// }
236-
// if (!std::is_sorted(sorted.begin(), sorted.end())) {
237-
// EXPECT_TRUE(false) << "Array of size " << size << "is not sorted";
238-
// }
239-
// EXPECT_UNIQUE(inx)
240-
// arr.clear();
241-
// }
242-
//}
243-
//
244-
//TYPED_TEST_P(avx512argselect, test_all_inf_array)
245-
//{
246-
// if (!cpu_has_avx512bw()) {
247-
// GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
248-
// }
249-
// std::vector<int64_t> arrsizes;
250-
// for (int64_t ii = 1; ii <= 256; ++ii) {
251-
// arrsizes.push_back(ii);
252-
// }
253-
// std::vector<TypeParam> arr;
254-
// for (auto &size : arrsizes) {
255-
// arr = get_uniform_rand_array<TypeParam>(size);
256-
// if (std::numeric_limits<TypeParam>::has_infinity) {
257-
// for (int64_t jj = 1; jj <= size; ++jj) {
258-
// if (rand() % 0x1) {
259-
// arr.push_back(std::numeric_limits<TypeParam>::infinity());
260-
// }
261-
// }
262-
// }
263-
// else {
264-
// for (int64_t jj = 1; jj <= size; ++jj) {
265-
// if (rand() % 0x1) {
266-
// arr.push_back(std::numeric_limits<TypeParam>::max());
267-
// }
268-
// }
269-
// }
270-
// std::vector<int64_t> inx = avx512_argsort(arr.data(), arr.size());
271-
// std::vector<TypeParam> sorted;
272-
// for (size_t jj = 0; jj < size; ++jj) {
273-
// sorted.push_back(arr[inx[jj]]);
274-
// }
275-
// if (!std::is_sorted(sorted.begin(), sorted.end())) {
276-
// EXPECT_TRUE(false) << "Array of size " << size << "is not sorted";
277-
// }
278-
// EXPECT_UNIQUE(inx)
279-
// arr.clear();
280-
// }
281-
//}
282-
283-
REGISTER_TYPED_TEST_SUITE_P(avx512argselect,
284-
test_random);
285-
//test_reverse,
286-
//test_constant,
287-
//test_sorted,
288-
//test_small_range,
289-
//test_all_inf_array,
290-
//test_array_with_nan,
291-
//test_max_value_at_end_of_array);
72+
REGISTER_TYPED_TEST_SUITE_P(avx512argselect, test_random);

tests/test-argsort-common.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@
66
#include "avx512-64bit-argsort.hpp"
77

88
template <typename T>
9-
std::vector<int64_t> std_argsort(const std::vector<T> &array)
9+
std::vector<int64_t> std_argsort(const std::vector<T> &arr)
1010
{
11-
std::vector<int64_t> indices(array.size());
11+
std::vector<int64_t> indices(arr.size());
1212
std::iota(indices.begin(), indices.end(), 0);
1313
std::sort(indices.begin(),
1414
indices.end(),
15-
[&array](int left, int right) -> bool {
16-
// sort indices according to corresponding array sizeent
17-
return array[left] < array[right];
18-
});
15+
[&arr](int64_t left, int64_t right) -> bool {
16+
if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {return arr[left] < arr[right];}
17+
else if (std::isnan(arr[left])) {return false;}
18+
else {return true;}
19+
});
1920

2021
return indices;
2122
}

0 commit comments

Comments
 (0)