Skip to content

Commit 92a628a

Browse files
author
Raghuveer Devulapalli
committed
Move argsort and argselect template func defintion to header
1 parent 51fe743 commit 92a628a

File tree

2 files changed

+16
-20
lines changed

2 files changed

+16
-20
lines changed

src/avx512-64bit-argsort.hpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -427,15 +427,6 @@ void avx512_argsort(float *arr, int64_t *arg, int64_t arrsize)
427427
}
428428
}
429429

430-
template <typename T>
431-
std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
432-
{
433-
std::vector<int64_t> indices(arrsize);
434-
std::iota(indices.begin(), indices.end(), 0);
435-
avx512_argsort<T>(arr, indices.data(), arrsize);
436-
return indices;
437-
}
438-
439430
/* argselect methods for 32-bit and 64-bit dtypes */
440431
template <typename T>
441432
void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize)
@@ -492,13 +483,5 @@ void avx512_argselect(float *arr, int64_t *arg, int64_t k, int64_t arrsize)
492483
}
493484
}
494485

495-
template <typename T>
496-
std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
497-
{
498-
std::vector<int64_t> indices(arrsize);
499-
std::iota(indices.begin(), indices.end(), 0);
500-
avx512_argselect<T>(arr, indices.data(), k, arrsize);
501-
return indices;
502-
}
503486

504487
#endif // AVX512_ARGSORT_64BIT

src/avx512-common-argsort.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,26 @@ template <typename T>
1919
void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize);
2020

2121
template <typename T>
22-
std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize);
22+
void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize);
2323

2424
template <typename T>
25-
void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize);
25+
std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
26+
{
27+
std::vector<int64_t> indices(arrsize);
28+
std::iota(indices.begin(), indices.end(), 0);
29+
avx512_argsort<T>(arr, indices.data(), arrsize);
30+
return indices;
31+
}
2632

2733
template <typename T>
28-
std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize);
34+
std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
35+
{
36+
std::vector<int64_t> indices(arrsize);
37+
std::iota(indices.begin(), indices.end(), 0);
38+
avx512_argselect<T>(arr, indices.data(), k, arrsize);
39+
return indices;
40+
}
41+
2942
/*
3043
* Parition one ZMM register based on the pivot and returns the index of the
3144
* last element that is less than equal to the pivot.

0 commit comments

Comments
 (0)