11#ifndef BITSANDBYTES_CPU_OPS_H
22#define BITSANDBYTES_CPU_OPS_H
33
4- #include < iostream>
5- #include < stdio.h>
64#include < cstdint>
75#include < cstring>
8- #include < type_traits>
96
107void quantize_cpu (float * code, float * A, float * absmax, unsigned char * out, long long blocksize, long long n);
118
@@ -14,7 +11,9 @@ typedef enum DataType_t {
1411 FP4 = 1 ,
1512} DataType_t;
1613
17- using fp16_t = _Float16;
/// Storage-only wrapper for a 16-bit half-precision value.
///
/// The struct carries just the raw 16-bit pattern in `v`; it defines no
/// arithmetic or conversions of its own.
struct fp16_t {
    uint16_t v; ///< Raw 16-bit encoding.
};
1817
1918struct bf16_t {
2019 uint16_t v;
@@ -27,6 +26,48 @@ static inline bf16_t float_to_bf16(float x) {
2726 return bf16_t {static_cast <uint16_t >(r >> 16 )};
2827}
2928
29+ static inline fp16_t float_to_fp16 (float x) {
30+ uint32_t bits;
31+ std::memcpy (&bits, &x, 4 );
32+ uint32_t sign = (bits >> 31 ) & 0x1 ;
33+ uint32_t exp = (bits >> 23 ) & 0xFF ;
34+ uint32_t mant = bits & 0x7FFFFF ;
35+
36+ uint16_t h;
37+ if (exp == 0xFF ) { // Inf / NaN
38+ uint16_t mant16 = mant ? 0x200 : 0 ; // quiet NaN: set MSB of mantissa
39+ h = (sign << 15 ) | (0x1F << 10 ) | mant16;
40+ } else if (exp > 0x70 + 0x1E ) { // overflow: exp_f -127 +15 > 30 (exp_f > 142)
41+ h = (sign << 15 ) | (0x1F << 10 ); // Inf
42+ } else if (exp < 0x71 ) { // subnormal or zero (exp_f < 113)
43+ if (exp < 0x67 ) { // too small -> zero (exp_f < 103)
44+ h = (sign << 15 );
45+ } else {
46+ // subnormal: implicit leading 1
47+ uint32_t shift = 0x71 - exp;
48+ uint32_t mant_with_hidden = mant | 0x800000 ;
49+ // add rounding bias before shifting (23-10 =13 bits to drop + shift)
50+ uint32_t rounded = (mant_with_hidden + (1u << (shift + 12 ))) >> (shift + 13 );
51+ h = (sign << 15 ) | (uint16_t )rounded;
52+ }
53+ } else {
54+ // normalized
55+ uint32_t exp_h = exp - 127 + 15 ;
56+ // round mantissa: add 2^(23-10-1) = 0x1000
57+ uint32_t mant_rounded = mant + 0x00001000 ;
58+ if (mant_rounded & 0x00800000 ) { // mantissa overflow after rounding
59+ mant_rounded = 0 ;
60+ ++exp_h;
61+ if (exp_h >= 0x1F ) { // overflow to Inf
62+ h = (sign << 15 ) | (0x1F << 10 );
63+ return fp16_t {h};
64+ }
65+ }
66+ h = (sign << 15 ) | ((uint16_t )exp_h << 10 ) | ((uint16_t )(mant_rounded >> 13 ));
67+ }
68+ return fp16_t {h};
69+ }
70+
3071inline float dDequantizeFP4 (unsigned char val) {
3172 if ((val & 0b1000 ) == 8 )
3273 if ((val & 0b0100 ) == 4 )
0 commit comments