Skip to content

Commit c2f2423

Browse files
author
Raghuveer Devulapalli
committed
Use templates to write avx512_argsort functions
1 parent be5b1c2 commit c2f2423

File tree

1 file changed

+6
-42
lines changed

1 file changed

+6
-42
lines changed

src/avx512-64bit-argsort.hpp

Lines changed: 6 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -270,57 +270,21 @@ inline void argsort_64bit_(type_t *arr,
270270
argsort_64bit_<vtype>(arr, arg, pivot_index, right, max_iters - 1);
271271
}
272272

273-
template <>
274-
void avx512_argsort<double>(double *arr, int64_t *arg, int64_t arrsize)
275-
{
276-
if (arrsize > 1) {
277-
argsort_64bit_<zmm_vector<double>, double>(
278-
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
279-
}
280-
}
281-
282-
template <>
283-
std::vector<int64_t> avx512_argsort<double>(double *arr, int64_t arrsize)
284-
{
285-
std::vector<int64_t> indices(arrsize);
286-
std::iota(indices.begin(), indices.end(), 0);
287-
avx512_argsort<double>(arr, indices.data(), arrsize);
288-
return indices;
289-
}
290-
291-
template <>
292-
void avx512_argsort<uint64_t>(uint64_t *arr, int64_t *arg, int64_t arrsize)
293-
{
294-
if (arrsize > 1) {
295-
argsort_64bit_<zmm_vector<uint64_t>, uint64_t>(
296-
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
297-
}
298-
}
299-
300-
template <>
301-
std::vector<int64_t> avx512_argsort<uint64_t>(uint64_t *arr, int64_t arrsize)
302-
{
303-
std::vector<int64_t> indices(arrsize);
304-
std::iota(indices.begin(), indices.end(), 0);
305-
avx512_argsort<uint64_t>(arr, indices.data(), arrsize);
306-
return indices;
307-
}
308-
309-
template <>
310-
void avx512_argsort<int64_t>(int64_t *arr, int64_t *arg, int64_t arrsize)
273+
template <typename T>
274+
void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
311275
{
312276
if (arrsize > 1) {
313-
argsort_64bit_<zmm_vector<int64_t>, int64_t>(
277+
argsort_64bit_<zmm_vector<T>>(
314278
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
315279
}
316280
}
317281

318-
template <>
319-
std::vector<int64_t> avx512_argsort<int64_t>(int64_t *arr, int64_t arrsize)
282+
template <typename T>
283+
std::vector<int64_t> avx512_argsort(T* arr, int64_t arrsize)
320284
{
321285
std::vector<int64_t> indices(arrsize);
322286
std::iota(indices.begin(), indices.end(), 0);
323-
avx512_argsort<int64_t>(arr, indices.data(), arrsize);
287+
avx512_argsort<T>(arr, indices.data(), arrsize);
324288
return indices;
325289
}
326290

0 commit comments

Comments
 (0)