Skip to content

Commit 2c4927d

Browse files
committed
Restore FP4 quantization/dequantization order
1 parent 4424b73 commit 2c4927d

File tree

2 files changed

+8
-31
lines changed

2 files changed

+8
-31
lines changed

bitsandbytes/functional.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -757,29 +757,6 @@ def get_4bit_type(typename, device=None, blocksize=64):
757757
# 0b111 = 3
758758
# can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
759759
data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
760-
if "cuda" in str(device).lower():
761-
# directly using the normalized (value/absmax) bins here
762-
# sorted [ascending order in positive range] + [ascending order in negative range]
763-
# to allow for faster quantization/dequantization on CUDA
764-
data = [
765-
0,
766-
0.005208333333,
767-
0.16666667,
768-
0.25,
769-
0.33333333,
770-
0.5,
771-
0.66666667,
772-
1.0,
773-
0,
774-
-0.005208333333,
775-
-0.16666667,
776-
-0.25,
777-
-0.33333333,
778-
-0.5,
779-
-0.66666667,
780-
-1.0,
781-
]
782-
783760
elif typename == "int4":
784761
data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
785762
elif typename == "af4":

csrc/kernels.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
__device__ static float fp4_dequantization_lut[8] = {
2525
0.0f, // 0b000
2626
0.005208333333f, // 0b001
27-
0.16666667f, // 0b010
28-
0.25f, // 0b011
27+
0.66666667f, // 0b010
28+
1.0f, // 0b011
2929
0.33333333f, // 0b100
3030
0.5f, // 0b101
31-
0.66666667f, // 0b110
32-
1.0f // 0b111
31+
0.16666667f, // 0b110
32+
0.25f // 0b111
3333
};
3434

3535
__device__ static float nf4_dequantization_lut[16] = {
@@ -93,18 +93,18 @@ __device__ unsigned char dQuantizeFP4(float x) {
9393
if (x > 0.29166667f)
9494
if (x > 0.583333f)
9595
if (x > 0.8333333f)
96-
return 0b0111 + sign;
96+
return 0b0011 + sign;
9797
else
98-
return 0b0110 + sign;
98+
return 0b0010 + sign;
9999
else if (x > 0.4166667f)
100100
return 0b101 + sign;
101101
else
102102
return 0b100 + sign;
103103
else if (x > 0.0859375f)
104104
if (x > 0.20833333f)
105-
return 0b011 + sign;
105+
return 0b0111 + sign;
106106
else
107-
return 0b0010 + sign;
107+
return 0b0110 + sign;
108108
else if (x > 0.00260417f)
109109
return 0b0001 + sign;
110110
else

0 commit comments

Comments
 (0)