Skip to content

Commit d7e981d

Browse files
committed
rm _Float16
Signed-off-by: jiqing-feng <[email protected]>
1 parent 6bcd19e commit d7e981d

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

csrc/cpu_ops.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,17 @@ void dequantizeBlockwise4bitCpu(unsigned char* A,
158158

159159
if constexpr (std::is_same<T, bf16_t>::value) {
160160
out[block_idx + i] = float_to_bf16(v0);
161+
} else if constexpr (std::is_same<T, fp16_t>::value) {
162+
out[block_idx + i] = float_to_fp16(v0);
161163
} else {
162164
out[block_idx + i] = static_cast<T>(v0);
163165
}
164166

165167
if (i + 1 < valid_items) {
166168
if constexpr (std::is_same<T, bf16_t>::value) {
167169
out[block_idx + i + 1] = float_to_bf16(v1);
170+
} else if constexpr (std::is_same<T, fp16_t>::value) {
171+
out[block_idx + i + 1] = float_to_fp16(v1);
168172
} else {
169173
out[block_idx + i + 1] = static_cast<T>(v1);
170174
}
@@ -192,6 +196,8 @@ void dequantizeBlockwise8bitCpu(float* code,
192196
float v = code[A[i]] * scale;
193197
if constexpr (std::is_same<T, bf16_t>::value) {
194198
out[i] = float_to_bf16(v);
199+
} else if constexpr (std::is_same<T, fp16_t>::value) {
200+
out[i] = float_to_fp16(v);
195201
} else {
196202
out[i] = static_cast<T>(v);
197203
}

csrc/cpu_ops.h

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
#ifndef BITSANDBYTES_CPU_OPS_H
22
#define BITSANDBYTES_CPU_OPS_H
33

4-
#include <iostream>
5-
#include <stdio.h>
64
#include <cstdint>
75
#include <cstring>
8-
#include <type_traits>
96

107
void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n);
118

@@ -14,7 +11,9 @@ typedef enum DataType_t {
1411
FP4 = 1,
1512
} DataType_t;
1613

17-
using fp16_t = _Float16;
14+
struct fp16_t {
15+
uint16_t v;
16+
};
1817

1918
struct bf16_t {
2019
uint16_t v;
@@ -27,6 +26,48 @@ static inline bf16_t float_to_bf16(float x) {
2726
return bf16_t{static_cast<uint16_t>(r >> 16)};
2827
}
2928

29+
static inline fp16_t float_to_fp16(float x) {
30+
uint32_t bits;
31+
std::memcpy(&bits, &x, 4);
32+
uint32_t sign = (bits >> 31) & 0x1;
33+
uint32_t exp = (bits >> 23) & 0xFF;
34+
uint32_t mant = bits & 0x7FFFFF;
35+
36+
uint16_t h;
37+
if (exp == 0xFF) { // Inf / NaN
38+
uint16_t mant16 = mant ? 0x200 : 0; // quiet NaN: set MSB of mantissa
39+
h = (sign << 15) | (0x1F << 10) | mant16;
40+
} else if (exp > 0x70 + 0x1E) { // overflow: exp_f -127 +15 > 30 (exp_f > 142)
41+
h = (sign << 15) | (0x1F << 10); // Inf
42+
} else if (exp < 0x71) { // subnormal or zero (exp_f < 113)
43+
if (exp < 0x67) { // too small -> zero (exp_f < 103)
44+
h = (sign << 15);
45+
} else {
46+
// subnormal: implicit leading 1
47+
uint32_t shift = 0x71 - exp;
48+
uint32_t mant_with_hidden = mant | 0x800000;
49+
// add rounding bias before shifting (23-10 =13 bits to drop + shift)
50+
uint32_t rounded = (mant_with_hidden + (1u << (shift + 12))) >> (shift + 13);
51+
h = (sign << 15) | (uint16_t)rounded;
52+
}
53+
} else {
54+
// normalized
55+
uint32_t exp_h = exp - 127 + 15;
56+
// round mantissa: add 2^(23-10-1) = 0x1000
57+
uint32_t mant_rounded = mant + 0x00001000;
58+
if (mant_rounded & 0x00800000) { // mantissa overflow after rounding
59+
mant_rounded = 0;
60+
++exp_h;
61+
if (exp_h >= 0x1F) { // overflow to Inf
62+
h = (sign << 15) | (0x1F << 10);
63+
return fp16_t{h};
64+
}
65+
}
66+
h = (sign << 15) | ((uint16_t)exp_h << 10) | ((uint16_t)(mant_rounded >> 13));
67+
}
68+
return fp16_t{h};
69+
}
70+
3071
inline float dDequantizeFP4(unsigned char val) {
3172
if ((val & 0b1000) == 8)
3273
if ((val & 0b0100) == 4)

0 commit comments

Comments
 (0)