@@ -270,57 +270,21 @@ inline void argsort_64bit_(type_t *arr,
270
270
argsort_64bit_<vtype>(arr, arg, pivot_index, right, max_iters - 1 );
271
271
}
272
272
273
- template <>
274
- void avx512_argsort<double >(double *arr, int64_t *arg, int64_t arrsize)
275
- {
276
- if (arrsize > 1 ) {
277
- argsort_64bit_<zmm_vector<double >, double >(
278
- arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
279
- }
280
- }
281
-
282
- template <>
283
- std::vector<int64_t > avx512_argsort<double >(double *arr, int64_t arrsize)
284
- {
285
- std::vector<int64_t > indices (arrsize);
286
- std::iota (indices.begin (), indices.end (), 0 );
287
- avx512_argsort<double >(arr, indices.data (), arrsize);
288
- return indices;
289
- }
290
-
291
- template <>
292
- void avx512_argsort<uint64_t >(uint64_t *arr, int64_t *arg, int64_t arrsize)
293
- {
294
- if (arrsize > 1 ) {
295
- argsort_64bit_<zmm_vector<uint64_t >, uint64_t >(
296
- arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
297
- }
298
- }
299
-
300
- template <>
301
- std::vector<int64_t > avx512_argsort<uint64_t >(uint64_t *arr, int64_t arrsize)
302
- {
303
- std::vector<int64_t > indices (arrsize);
304
- std::iota (indices.begin (), indices.end (), 0 );
305
- avx512_argsort<uint64_t >(arr, indices.data (), arrsize);
306
- return indices;
307
- }
308
-
309
- template <>
310
- void avx512_argsort<int64_t >(int64_t *arr, int64_t *arg, int64_t arrsize)
273
+ template <typename T>
274
+ void avx512_argsort (T* arr, int64_t *arg, int64_t arrsize)
311
275
{
312
276
if (arrsize > 1 ) {
313
- argsort_64bit_<zmm_vector<int64_t >, int64_t >(
277
+ argsort_64bit_<zmm_vector<T> >(
314
278
arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
315
279
}
316
280
}
317
281
318
- template <>
319
- std::vector<int64_t > avx512_argsort< int64_t >( int64_t * arr, int64_t arrsize)
282
+ template <typename T >
283
+ std::vector<int64_t > avx512_argsort (T* arr, int64_t arrsize)
320
284
{
321
285
std::vector<int64_t > indices (arrsize);
322
286
std::iota (indices.begin (), indices.end (), 0 );
323
- avx512_argsort<int64_t >(arr, indices.data (), arrsize);
287
+ avx512_argsort<T >(arr, indices.data (), arrsize);
324
288
return indices;
325
289
}
326
290
0 commit comments