@@ -64,25 +64,33 @@ struct zmm_vector<float16> {
64
64
65
65
// Lane-wise "x >= y" on packed 16-bit values, computed with integer
// intrinsics on the raw bit patterns rather than a floating-point compare.
//
// Each 16-bit lane is split by masking into a sign field (0x8000), an
// exponent field (0x7c00) and a mantissa field (0x3ff). The result mask is
// built in stages:
//   1. set lanes where sign_x < sign_y (unsigned), i.e. x has the sign bit
//      clear and y has it set;
//   2. among lanes with equal sign, set lanes where exp_x > exp_y
//      (_MM_CMPINT_NLE);
//   3. among lanes with equal sign and exponent, set lanes where
//      mant_x >= mant_y (_MM_CMPINT_NLT);
//   4. XOR with `neg` (lanes where both sign bits are set), which inverts the
//      magnitude-based result when both operands are negative.
//
// Returns a 32-bit mask with one bit per 16-bit lane.
//
// NOTE(review): when both lanes are negative and bit-identical, step 3 sets
// the bit and step 4 clears it, so the result for x == y is 0 in that case
// while it is 1 for equal non-negative lanes -- confirm this asymmetry is
// intended by the callers.
// NOTE(review): no special handling of exponent == 0x7c00 patterns is visible
// here; presumably NaNs are dealt with before this is called -- verify.
static opmask_t ge (zmm_t x, zmm_t y)
66
66
{
67
- zmm_t sign_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x8000 ));
68
- zmm_t sign_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x8000 ));
69
- zmm_t exp_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x7c00 ));
70
- zmm_t exp_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x7c00 ));
71
- zmm_t mant_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x3ff ));
72
- zmm_t mant_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x3ff ));
73
-
74
- __mmask32 mask_ge = _mm512_cmp_epu16_mask (sign_x, sign_y, _MM_CMPINT_LT); // only greater than
75
- __mmask32 sign_eq = _mm512_cmpeq_epu16_mask (sign_x, sign_y);
76
- __mmask32 neg = _mm512_mask_cmpeq_epu16_mask (sign_eq, sign_x, _mm512_set1_epi16 (0x8000 )); // both numbers are -ve
77
-
78
- // compare exponents only if signs are equal:
79
- mask_ge = mask_ge | _mm512_mask_cmp_epu16_mask (sign_eq, exp_x, exp_y, _MM_CMPINT_NLE);
80
- // get mask for elements for which both sign and exponents are equal:
81
- __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask (sign_eq, exp_x, exp_y);
82
-
83
- // compare mantissa for elements for which both sign and expponent are equal:
84
- mask_ge = mask_ge | _mm512_mask_cmp_epu16_mask (exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
85
- return _kxor_mask32 (mask_ge, neg);
67
+ zmm_t sign_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x8000 ));
68
+ zmm_t sign_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x8000 ));
69
+ zmm_t exp_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x7c00 ));
70
+ zmm_t exp_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x7c00 ));
71
+ zmm_t mant_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x3ff ));
72
+ zmm_t mant_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x3ff ));
73
+
74
+ __mmask32 mask_ge = _mm512_cmp_epu16_mask (
75
+ sign_x, sign_y, _MM_CMPINT_LT); // only greater than
76
+ __mmask32 sign_eq = _mm512_cmpeq_epu16_mask (sign_x, sign_y);
77
+ __mmask32 neg = _mm512_mask_cmpeq_epu16_mask (
78
+ sign_eq,
79
+ sign_x,
80
+ _mm512_set1_epi16 (0x8000 )); // both numbers are -ve
81
+
82
+ // compare exponents only if signs are equal:
83
+ mask_ge = mask_ge
84
+ | _mm512_mask_cmp_epu16_mask (
85
+ sign_eq, exp_x, exp_y, _MM_CMPINT_NLE);
86
+ // get mask for elements for which both sign and exponents are equal:
87
+ __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask (sign_eq, exp_x, exp_y);
88
+
89
+ // compare mantissa for elements for which both sign and exponent are equal:
90
+ mask_ge = mask_ge
91
+ | _mm512_mask_cmp_epu16_mask (
92
+ exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
93
+ return _kxor_mask32 (mask_ge, neg);
86
94
}
87
95
static zmm_t loadu (void const *mem)
88
96
{
@@ -549,8 +557,8 @@ X86_SIMD_SORT_FINLINE void sort_128_16bit(type_t *arr, int32_t N)
549
557
550
558
template <typename vtype, typename type_t >
551
559
X86_SIMD_SORT_FINLINE type_t get_pivot_16bit (type_t *arr,
552
- const int64_t left,
553
- const int64_t right)
560
+ const int64_t left,
561
+ const int64_t right)
554
562
{
555
563
// median of 32
556
564
int64_t size = (right - left) / 32 ;
@@ -598,26 +606,22 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
598
606
uint16_t expa = a & 0x7c00 , expb = b & 0x7c00 ;
599
607
uint16_t manta = a & 0x3ff , mantb = b & 0x3ff ;
600
608
if (signa != signb) {
601
- // opposite signs
602
- return a > b;
609
+ // opposite signs
610
+ return a > b;
603
611
}
604
612
else if (signa > 0 ) {
605
- // both -ve
606
- if (expa != expb) {
607
- return expa > expb;
608
- }
609
- else {
610
- return manta > mantb;
611
- }
613
+ // both -ve
614
+ if (expa != expb) { return expa > expb; }
615
+ else {
616
+ return manta > mantb;
617
+ }
612
618
}
613
619
else {
614
- // both +ve
615
- if (expa != expb) {
616
- return expa < expb;
617
- }
618
- else {
619
- return manta < mantb;
620
- }
620
+ // both +ve
621
+ if (expa != expb) { return expa < expb; }
622
+ else {
623
+ return manta < mantb;
624
+ }
621
625
}
622
626
623
627
// return npy_half_to_float(a) < npy_half_to_float(b);
@@ -653,7 +657,8 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
653
657
qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1 );
654
658
}
655
659
656
- X86_SIMD_SORT_FINLINE int64_t replace_nan_with_inf (uint16_t *arr, int64_t arrsize)
660
+ X86_SIMD_SORT_FINLINE int64_t replace_nan_with_inf (uint16_t *arr,
661
+ int64_t arrsize)
657
662
{
658
663
int64_t nan_count = 0 ;
659
664
__mmask16 loadmask = 0xFFFF ;
0 commit comments