Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/c-cpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ jobs:

- name: Run test suite on SPR
run: sde -spr -- ./builddir/testexe
- name: Run ICL fp16 tests
# Note: this selects the _Float16 tests by the numeric index that appears in their test names (the `2` in the filter below); that index could change in the future
run: sde -icx -- ./builddir/testexe --gtest_filter="*/simdsort/2*"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The np-multiarray-tgl job does test the float16 portion of the code on a TGL, but this is fine too.


SKX-SKL-openmp:

Expand Down
27 changes: 27 additions & 0 deletions lib/x86simdsort-icl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,31 @@ namespace avx512 {
x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending);
}
} // namespace avx512
// _Float16 entry points in the fp16_icl namespace. Each function below is an
// explicit specialization for _Float16 that forwards its arguments, unchanged
// and in the same order, to the x86simdsortStatic function of the same name.
namespace fp16_icl {
// The specializations are compiled only when the compiler defines
// __FLT16_MAX__; otherwise this namespace is left empty.
#ifdef __FLT16_MAX__
// qsort specialization for _Float16: forwards to x86simdsortStatic::qsort.
template <>
void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending)
{
x86simdsortStatic::qsort(arr, size, hasnan, descending);
}
// qselect specialization for _Float16: forwards to x86simdsortStatic::qselect.
template <>
void qselect(_Float16 *arr,
size_t k,
size_t arrsize,
bool hasnan,
bool descending)
{
x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending);
}
// partial_qsort specialization for _Float16: forwards to
// x86simdsortStatic::partial_qsort.
template <>
void partial_qsort(_Float16 *arr,
size_t k,
size_t arrsize,
bool hasnan,
bool descending)
{
x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending);
}
#endif
} // namespace fp16_icl
} // namespace xss
213 changes: 54 additions & 159 deletions lib/x86simdsort-internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,165 +4,60 @@
#include <stdint.h>
#include <vector>

/* DECLAREALLFUNCS(name): opens `namespace name` and declares (without
 * defining) the following function templates inside it, each marked with
 * XSS_HIDE_SYMBOL:
 *   - qsort, qselect, partial_qsort                          (template <T>)
 *   - keyvalue_qsort, keyvalue_select, keyvalue_partial_sort (template <T1, T2>)
 *   - argsort, argselect                (template <T>, return std::vector<size_t>)
 * In every declaration `hasnan` defaults to false; `descending` also defaults
 * to false wherever it is present. argselect is the only function declared
 * without a `descending` parameter.
 * The macro expands to a complete namespace block, so it can be invoked once
 * per namespace to produce an identical set of declarations. */
#define DECLAREALLFUNCS(name) \
namespace name { \
template <typename T> \
XSS_HIDE_SYMBOL void qsort(T *arr, \
size_t arrsize, \
bool hasnan = false, \
bool descending = false); \
template <typename T1, typename T2> \
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, \
T2 *val, \
size_t arrsize, \
bool hasnan = false, \
bool descending = false); \
template <typename T> \
XSS_HIDE_SYMBOL void qselect(T *arr, \
size_t k, \
size_t arrsize, \
bool hasnan = false, \
bool descending = false); \
template <typename T1, typename T2> \
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, \
T2 *val, \
size_t k, \
size_t arrsize, \
bool hasnan = false, \
bool descending = false); \
template <typename T> \
XSS_HIDE_SYMBOL void partial_qsort(T *arr, \
size_t k, \
size_t arrsize, \
bool hasnan = false, \
bool descending = false); \
template <typename T1, typename T2> \
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, \
T2 *val, \
size_t k, \
size_t arrsize, \
bool hasnan = false, \
bool descending = false); \
template <typename T> \
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr, \
size_t arrsize, \
bool hasnan = false, \
bool descending = false); \
template <typename T> \
XSS_HIDE_SYMBOL std::vector<size_t> \
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); \
}

namespace xss {
namespace avx512 {
// quicksort
template <typename T>
XSS_HIDE_SYMBOL void
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
// key-value quicksort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
T2 *val,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// quickselect
template <typename T>
XSS_HIDE_SYMBOL void qselect(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value select
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// partial sort
template <typename T>
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value partial sort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argsort
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argselect
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t>
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
} // namespace avx512
namespace avx2 {
// quicksort
template <typename T>
XSS_HIDE_SYMBOL void
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
// key-value quicksort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
T2 *val,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// quickselect
template <typename T>
XSS_HIDE_SYMBOL void qselect(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value select
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// partial sort
template <typename T>
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value partial sort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argsort
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argselect
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t>
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
} // namespace avx2
namespace scalar {
// quicksort
template <typename T>
XSS_HIDE_SYMBOL void
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
// key-value quicksort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
T2 *val,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// quickselect
template <typename T>
XSS_HIDE_SYMBOL void qselect(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value select
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// partial sort
template <typename T>
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value partial sort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argsort
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argselect
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t>
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
} // namespace scalar
DECLAREALLFUNCS(avx512)
DECLAREALLFUNCS(avx2)
DECLAREALLFUNCS(scalar)
DECLAREALLFUNCS(fp16_spr)
DECLAREALLFUNCS(fp16_icl)
} // namespace xss
#endif
4 changes: 2 additions & 2 deletions lib/x86simdsort-spr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include "x86simdsort-internal.h"

namespace xss {
namespace avx512 {
namespace fp16_spr {
template <>
void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending)
{
Expand All @@ -27,5 +27,5 @@ namespace avx512 {
{
x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending);
}
} // namespace avx512
} // namespace fp16_spr
} // namespace xss
36 changes: 32 additions & 4 deletions lib/x86simdsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ namespace x86simdsort {
return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \
}

/* Compile-time predicate: yields true only when T is exactly _Float16 and the
 * compiler defines __FLT16_MAX__; yields false for every other type, and for
 * all types when __FLT16_MAX__ is not defined. Keeping the preprocessor check
 * in this helper means the DISPATCH macro does not need an
 * #ifdef __FLT16_MAX__ block in its body. */
template <typename T>
constexpr bool IS_TYPE_FLOAT16()
{
#ifdef __FLT16_MAX__
    return std::is_same_v<T, _Float16>;
#else
    return false;
#endif
}

/* runtime dispatch mechanism */
#define DISPATCH(func, TYPE, ISA) \
DECLARE_INTERNAL_##func(TYPE) static __attribute__((constructor)) void \
Expand All @@ -118,7 +129,24 @@ namespace x86simdsort {
std::string_view preferred_cpu = find_preferred_cpu(ISA); \
if constexpr (dispatch_requested("avx512", ISA)) { \
if (preferred_cpu.find("avx512") != std::string_view::npos) { \
CAT(CAT(internal_, func), TYPE) = &xss::avx512::func<TYPE>; \
if constexpr (IS_TYPE_FLOAT16<TYPE>()) { \
if (preferred_cpu.find("avx512_spr") \
!= std::string_view::npos) { \
CAT(CAT(internal_, func), TYPE) \
= &xss::fp16_spr::func<TYPE>; \
return; \
} \
if (preferred_cpu.find("avx512_icl") \
!= std::string_view::npos) { \
CAT(CAT(internal_, func), TYPE) \
= &xss::fp16_icl::func<TYPE>; \
return; \
} \
} \
else { \
CAT(CAT(internal_, func), TYPE) \
= &xss::avx512::func<TYPE>; \
} \
return; \
} \
} \
Expand All @@ -137,9 +165,9 @@ namespace x86simdsort {
}

#ifdef __FLT16_MAX__
DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr"))
DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr"))
DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr"))
DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
DISPATCH(argsort, _Float16, ISA_LIST("none"))
DISPATCH(argselect, _Float16, ISA_LIST("none"))
#endif
Expand Down
Loading