@@ -19,13 +19,26 @@ template <typename T>
19
19
void avx512_argsort (T *arr, int64_t *arg, int64_t arrsize);
20
20
21
21
template <typename T>
22
- std::vector< int64_t > avx512_argsort (T *arr, int64_t arrsize);
22
+ void avx512_argselect (T *arr, int64_t *arg, int64_t k , int64_t arrsize);
23
23
24
24
template <typename T>
25
- void avx512_argselect (T *arr, int64_t *arg, int64_t k, int64_t arrsize);
25
+ std::vector<int64_t > avx512_argsort (T *arr, int64_t arrsize)
26
+ {
27
+ std::vector<int64_t > indices (arrsize);
28
+ std::iota (indices.begin (), indices.end (), 0 );
29
+ avx512_argsort<T>(arr, indices.data (), arrsize);
30
+ return indices;
31
+ }
26
32
27
33
template <typename T>
28
- std::vector<int64_t > avx512_argselect (T *arr, int64_t k, int64_t arrsize);
34
+ std::vector<int64_t > avx512_argselect (T *arr, int64_t k, int64_t arrsize)
35
+ {
36
+ std::vector<int64_t > indices (arrsize);
37
+ std::iota (indices.begin (), indices.end (), 0 );
38
+ avx512_argselect<T>(arr, indices.data (), k, arrsize);
39
+ return indices;
40
+ }
41
+
29
42
/*
30
43
* Parition one ZMM register based on the pivot and returns the index of the
31
44
* last element that is less than equal to the pivot.
0 commit comments