@@ -64,12 +64,12 @@ struct zmm_vector<float16> {
64
64
65
65
static opmask_t ge (zmm_t x, zmm_t y)
66
66
{
67
- zmm_t sign_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x8000 ));
68
- zmm_t sign_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x8000 ));
69
- zmm_t exp_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x7c00 ));
70
- zmm_t exp_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x7c00 ));
71
- zmm_t mant_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x3ff ));
72
- zmm_t mant_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x3ff ));
67
+ zmm_t sign_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x8000 ));
68
+ zmm_t sign_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x8000 ));
69
+ zmm_t exp_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x7c00 ));
70
+ zmm_t exp_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x7c00 ));
71
+ zmm_t mant_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x3ff ));
72
+ zmm_t mant_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x3ff ));
73
73
74
74
__mmask32 mask_ge = _mm512_cmp_epu16_mask (sign_x, sign_y, _MM_CMPINT_LT); // only greater than
75
75
__mmask32 sign_eq = _mm512_cmpeq_epu16_mask (sign_x, sign_y);
@@ -595,8 +595,8 @@ template <>
595
595
bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
596
596
{
597
597
uint16_t signa = a & 0x8000 , signb = b & 0x8000 ;
598
- uint16_t expa = a & 0x7c00 , expb = b & 0x7c00 ;
599
- uint16_t manta = a & 0x3ff , mantb = b & 0x3ff ;
598
+ uint16_t expa = a & 0x7c00 , expb = b & 0x7c00 ;
599
+ uint16_t manta = a & 0x3ff , mantb = b & 0x3ff ;
600
600
if (signa != signb) {
601
601
// opposite signs
602
602
return a > b;
@@ -605,7 +605,7 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
605
605
// both -ve
606
606
if (expa != expb) {
607
607
return expa > expb;
608
- }
608
+ }
609
609
else {
610
610
return manta > mantb;
611
611
}
@@ -614,12 +614,12 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
614
614
// both +ve
615
615
if (expa != expb) {
616
616
return expa < expb;
617
- }
617
+ }
618
618
else {
619
619
return manta < mantb;
620
620
}
621
621
}
622
-
622
+
623
623
// return npy_half_to_float(a) < npy_half_to_float(b);
624
624
}
625
625
0 commit comments