Skip to content

Commit bdd0af6

Browse files
author
Raghuveer Devulapalli
committed
Remove template specializations for arg methods
1 parent 34c2798 commit bdd0af6

File tree

2 files changed

+30
-110
lines changed

2 files changed

+30
-110
lines changed

src/avx512-64bit-argsort.hpp

Lines changed: 30 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -348,113 +348,57 @@ static void argselect_64bit_(type_t *arr,
348348
template <typename T>
349349
void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize)
350350
{
351+
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
352+
ymm_vector<T>,
353+
zmm_vector<T>>::type;
351354
if (arrsize > 1) {
352-
argsort_64bit_<zmm_vector<T>>(
353-
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
354-
}
355-
}
356-
357-
template <>
358-
void avx512_argsort(double *arr, int64_t *arg, int64_t arrsize)
359-
{
360-
if (arrsize > 1) {
361-
if (has_nan<zmm_vector<double>>(arr, arrsize)) {
362-
std_argsort_withnan(arr, arg, 0, arrsize);
355+
if constexpr (std::is_floating_point_v<T>) {
356+
if (has_nan<vectype>(arr, arrsize)) {
357+
std_argsort_withnan(arr, arg, 0, arrsize);
358+
return;
359+
}
363360
}
364-
else {
365-
argsort_64bit_<zmm_vector<double>>(
366-
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
367-
}
368-
}
369-
}
370-
371-
template <>
372-
void avx512_argsort(int32_t *arr, int64_t *arg, int64_t arrsize)
373-
{
374-
if (arrsize > 1) {
375-
argsort_64bit_<ymm_vector<int32_t>>(
361+
argsort_64bit_<vectype>(
376362
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
377363
}
378364
}
379365

380-
template <>
381-
void avx512_argsort(uint32_t *arr, int64_t *arg, int64_t arrsize)
382-
{
383-
if (arrsize > 1) {
384-
argsort_64bit_<ymm_vector<uint32_t>>(
385-
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
386-
}
387-
}
388-
389-
template <>
390-
void avx512_argsort(float *arr, int64_t *arg, int64_t arrsize)
366+
template <typename T>
367+
std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
391368
{
392-
if (arrsize > 1) {
393-
if (has_nan<ymm_vector<float>>(arr, arrsize)) {
394-
std_argsort_withnan(arr, arg, 0, arrsize);
395-
}
396-
else {
397-
argsort_64bit_<ymm_vector<float>>(
398-
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
399-
}
400-
}
369+
std::vector<int64_t> indices(arrsize);
370+
std::iota(indices.begin(), indices.end(), 0);
371+
avx512_argsort<T>(arr, indices.data(), arrsize);
372+
return indices;
401373
}
402374

403375
/* argselect methods for 32-bit and 64-bit dtypes */
404376
template <typename T>
405377
void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize)
406378
{
407-
if (arrsize > 1) {
408-
argselect_64bit_<zmm_vector<T>>(
409-
arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
410-
}
411-
}
379+
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
380+
ymm_vector<T>,
381+
zmm_vector<T>>::type;
412382

413-
template <>
414-
void avx512_argselect(double *arr, int64_t *arg, int64_t k, int64_t arrsize)
415-
{
416383
if (arrsize > 1) {
417-
if (has_nan<zmm_vector<double>>(arr, arrsize)) {
418-
std_argselect_withnan(arr, arg, k, 0, arrsize);
419-
}
420-
else {
421-
argselect_64bit_<zmm_vector<double>>(
422-
arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
384+
if constexpr (std::is_floating_point_v<T>) {
385+
if (has_nan<vectype>(arr, arrsize)) {
386+
std_argselect_withnan(arr, arg, k, 0, arrsize);
387+
return;
388+
}
423389
}
424-
}
425-
}
426-
427-
template <>
428-
void avx512_argselect(int32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
429-
{
430-
if (arrsize > 1) {
431-
argselect_64bit_<ymm_vector<int32_t>>(
390+
argselect_64bit_<vectype>(
432391
arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
433392
}
434393
}
435394

436-
template <>
437-
void avx512_argselect(uint32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
438-
{
439-
if (arrsize > 1) {
440-
argselect_64bit_<ymm_vector<uint32_t>>(
441-
arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
442-
}
443-
}
444-
445-
template <>
446-
void avx512_argselect(float *arr, int64_t *arg, int64_t k, int64_t arrsize)
395+
template <typename T>
396+
std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
447397
{
448-
if (arrsize > 1) {
449-
if (has_nan<ymm_vector<float>>(arr, arrsize)) {
450-
std_argselect_withnan(arr, arg, k, 0, arrsize);
451-
}
452-
else {
453-
argselect_64bit_<ymm_vector<float>>(
454-
arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
455-
}
456-
}
398+
std::vector<int64_t> indices(arrsize);
399+
std::iota(indices.begin(), indices.end(), 0);
400+
avx512_argselect<T>(arr, indices.data(), k, arrsize);
401+
return indices;
457402
}
458403

459-
460404
#endif // AVX512_ARGSORT_64BIT

src/avx512-common-argsort.h

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,6 @@
1515
using argtype = zmm_vector<int64_t>;
1616
using argzmm_t = typename argtype::zmm_t;
1717

18-
template <typename T>
19-
void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize);
20-
21-
template <typename T>
22-
void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize);
23-
24-
template <typename T>
25-
std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
26-
{
27-
std::vector<int64_t> indices(arrsize);
28-
std::iota(indices.begin(), indices.end(), 0);
29-
avx512_argsort<T>(arr, indices.data(), arrsize);
30-
return indices;
31-
}
32-
33-
template <typename T>
34-
std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
35-
{
36-
std::vector<int64_t> indices(arrsize);
37-
std::iota(indices.begin(), indices.end(), 0);
38-
avx512_argselect<T>(arr, indices.data(), k, arrsize);
39-
return indices;
40-
}
41-
4218
/*
4319
* Parition one ZMM register based on the pivot and returns the index of the
4420
* last element that is less than equal to the pivot.

0 commit comments

Comments
 (0)