Skip to content

Commit 3adb194

Browse files
committed
Better errors for invalid types in sorting functions
1 parent 724e92e commit 3adb194

File tree

4 files changed

+73
-11
lines changed

4 files changed

+73
-11
lines changed

src/xss-common-argsort.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,14 @@ X86_SIMD_SORT_INLINE void argselect_(type_t *arr,
584584
arr, arg, pos, pivot_index, right, max_iters - 1);
585585
}
586586

587+
template <typename T, typename vtype>
588+
X86_SIMD_SORT_FINLINE bool is_sorted(T *arr, arrsize_t arrsize, bool descending)
589+
{
590+
auto comp = descending ? Comparator<vtype, true>::STDSortComparator
591+
: Comparator<vtype, false>::STDSortComparator;
592+
return std::is_sorted(arr, arr + arrsize, comp);
593+
}
594+
587595
/* argsort methods for 32-bit and 64-bit dtypes */
588596
template <typename T,
589597
template <typename...>
@@ -600,11 +608,12 @@ X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
600608
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
601609
half_vector<T>,
602610
full_vector<T>>::type;
603-
604611
using argtype =
605612
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
606613
half_vector<arrsize_t>,
607614
full_vector<arrsize_t>>::type;
615+
static_assert(is_valid_vector_type_key_value<vectype, argtype>(),
616+
"Invalid type for argsort!");
608617

609618
if (arrsize > 1) {
610619
/* simdargsort does not work for float/double arrays with nan */
@@ -620,9 +629,7 @@ X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
620629
UNUSED(hasnan);
621630

622631
/* early exit for already sorted arrays: float/double with nan never reach here*/
623-
auto comp = descending ? Comparator<vectype, true>::STDSortComparator
624-
: Comparator<vectype, false>::STDSortComparator;
625-
if (std::is_sorted(arr, arr + arrsize, comp)) { return; }
632+
if (is_sorted<T, vectype>(arr, arrsize, descending)) { return; }
626633

627634
#ifdef XSS_COMPILE_OPENMP
628635

@@ -708,11 +715,12 @@ X86_SIMD_SORT_INLINE void xss_argselect(T *arr,
708715
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
709716
half_vector<T>,
710717
full_vector<T>>::type;
711-
712718
using argtype =
713719
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
714720
half_vector<arrsize_t>,
715721
full_vector<arrsize_t>>::type;
722+
static_assert(is_valid_vector_type_key_value<vectype, argtype>(),
723+
"Invalid type for argselect!");
716724

717725
if (arrsize > 1) {
718726
if constexpr (xss::fp::is_floating_point_v<T>) {

src/xss-common-includes.h

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,27 @@ constexpr bool always_false = false;
9292

9393
typedef size_t arrsize_t;
9494

95-
template <typename type>
96-
struct zmm_vector;
95+
enum class simd_type : int { INVALID, AVX2, AVX512 };
9796

9897
template <typename type>
99-
struct ymm_vector;
98+
struct zmm_vector {
99+
static constexpr simd_type vec_type = simd_type::INVALID;
100+
};
100101

101102
template <typename type>
102-
struct avx2_vector;
103+
struct ymm_vector {
104+
static constexpr simd_type vec_type = simd_type::INVALID;
105+
};
103106

104107
template <typename type>
105-
struct avx2_half_vector;
108+
struct avx2_vector {
109+
static constexpr simd_type vec_type = simd_type::INVALID;
110+
};
106111

107-
enum class simd_type : int { AVX2, AVX512 };
112+
template <typename type>
113+
struct avx2_half_vector {
114+
static constexpr simd_type vec_type = simd_type::INVALID;
115+
};
108116

109117
template <typename vtype, typename T = typename vtype::type_t>
110118
X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b);
@@ -113,4 +121,29 @@ struct float16 {
113121
uint16_t val;
114122
};
115123

124+
template <typename vtype>
125+
constexpr bool is_valid_vector_type()
126+
{
127+
return vtype::vec_type != simd_type::INVALID;
128+
}
129+
130+
template <typename vtype>
131+
constexpr bool is_valid_vector_type_32_or_64_bit()
132+
{
133+
if constexpr (is_valid_vector_type<vtype>()) {
134+
constexpr int type_size = sizeof(typename vtype::type_t);
135+
return type_size == 4 || type_size == 8;
136+
}
137+
else {
138+
return false;
139+
}
140+
}
141+
142+
template <typename vtype1, typename vtype2>
143+
constexpr bool is_valid_vector_type_key_value()
144+
{
145+
return is_valid_vector_type_32_or_64_bit<vtype1>()
146+
&& is_valid_vector_type_32_or_64_bit<vtype2>();
147+
}
148+
116149
#endif // XSS_COMMON_INCLUDES

src/xss-common-keyvaluesort.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,8 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
580580
&& sizeof(T2) == sizeof(int32_t),
581581
half_vector<T2>,
582582
full_vector<T2>>::type;
583+
static_assert(is_valid_vector_type_key_value<keytype, valtype>(),
584+
"Invalid type for keyvalue_qsort!");
583585

584586
// Exit early if no work would be done
585587
if (arrsize <= 1) return;
@@ -677,6 +679,8 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys,
677679
&& sizeof(T2) == sizeof(int32_t),
678680
half_vector<T2>,
679681
full_vector<T2>>::type;
682+
static_assert(is_valid_vector_type_key_value<keytype, valtype>(),
683+
"Invalid type for keyvalue_select!");
680684

681685
// Exit early if no work would be done
682686
if (arrsize <= 1) return;
@@ -732,6 +736,19 @@ X86_SIMD_SORT_INLINE void xss_partial_sort_kv(T1 *keys,
732736
bool hasnan,
733737
bool descending)
734738
{
739+
using keytype =
740+
typename std::conditional<sizeof(T1) != sizeof(T2)
741+
&& sizeof(T1) == sizeof(int32_t),
742+
half_vector<T1>,
743+
full_vector<T1>>::type;
744+
using valtype =
745+
typename std::conditional<sizeof(T1) != sizeof(T2)
746+
&& sizeof(T2) == sizeof(int32_t),
747+
half_vector<T2>,
748+
full_vector<T2>>::type;
749+
static_assert(is_valid_vector_type_key_value<keytype, valtype>(),
750+
"Invalid type for keyvalue_partial_sort!");
751+
735752
if (k == 0) return;
736753
xss_select_kv<T1, T2, full_vector, half_vector>(
737754
keys, indexes, k - 1, arrsize, hasnan, descending);

src/xss-common-qsort.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr,
652652
template <typename vtype, typename T, bool descending = false>
653653
X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan)
654654
{
655+
static_assert(is_valid_vector_type<vtype>(), "Invalid type for qsort!");
655656
using comparator =
656657
typename std::conditional<descending,
657658
Comparator<vtype, true>,
@@ -716,6 +717,7 @@ template <typename vtype, typename T, bool descending = false>
716717
X86_SIMD_SORT_INLINE void
717718
xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
718719
{
720+
static_assert(is_valid_vector_type<vtype>(), "Invalid type for qselect!");
719721
using comparator =
720722
typename std::conditional<descending,
721723
Comparator<vtype, true>,
@@ -758,6 +760,8 @@ template <typename vtype, typename T, bool descending = false>
758760
X86_SIMD_SORT_INLINE void
759761
xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
760762
{
763+
static_assert(is_valid_vector_type<vtype>(),
764+
"Invalid type for partial_qsort!");
761765
if (k == 0) return;
762766
xss_qselect<vtype, T, descending>(arr, k - 1, arrsize, hasnan);
763767
xss_qsort<vtype, T, descending>(arr, k - 1, hasnan);

0 commit comments

Comments
 (0)