2121#define NUM 4
2222#define NUM_BLOCK 4096
2323
// FP4 dequantization table. Indexed by the low three bits of a 4-bit FP4
// code; the sign bit (bit 3) is not part of the index and is applied by
// dDequantizeFP4Tree. Values are unscaled (not yet multiplied by absmax).
__device__ static float fp4_dequantization_lut[8] = {
    0.0f,            // 0b000
    0.005208333333f, // 0b001
    0.66666667f,     // 0b010
    1.0f,            // 0b011
    0.33333333f,     // 0b100
    0.5f,            // 0b101
    0.16666667f,     // 0b110
    0.25f            // 0b111
};

// NF4 dequantization table. Indexed by the full 4-bit NF4 code; entries are
// strictly increasing from -1.0f (code 0b0000) to 1.0f (code 0b1111).
__device__ static float nf4_dequantization_lut[16] = {
    -1.0f,                 // 0b0000
    -0.6961928009986877f,  // 0b0001
    -0.5250730514526367f,  // 0b0010
    -0.39491748809814453f, // 0b0011
    -0.28444138169288635f, // 0b0100
    -0.18477343022823334f, // 0b0101
    -0.09105003625154495f, // 0b0110
    0.0f,                  // 0b0111
    0.07958029955625534f,  // 0b1000
    0.16093020141124725f,  // 0b1001
    0.24611230194568634f,  // 0b1010
    0.33791524171829224f,  // 0b1011
    0.44070982933044434f,  // 0b1100
    0.5626170039176941f,   // 0b1101
    0.7229568362236023f,   // 0b1110
    1.0f                   // 0b1111
};
4253
4354// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
@@ -51,27 +62,9 @@ __device__ float atomicMax(float* address, float val) {
5162 return __int_as_float (old);
5263}
5364
// Dequantizes one 4-bit FP4 code to its unscaled float value.
//
// val: FP4 code in the low nibble. Bit 3 is the sign (set => negative) and
//      bits 0-2 select the magnitude from fp4_dequantization_lut. Bits above
//      the low nibble are masked off and ignored.
// Returns the signed table value; scaling by absmax is left to the caller.
__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) {
    // Branch-free sign: bit 3 clear -> +1.0f, bit 3 set -> -1.0f.
    float sign = 1.0f - 2 * ((val & 0b1000) >> 3);
    return fp4_dequantization_lut[val & 0b111] * sign;
}
7669
7770__device__ unsigned char dQuantizeFP4 (float x) {
@@ -118,51 +111,7 @@ __device__ unsigned char dQuantizeFP4(float x) {
118111 return 0b0000 + sign;
119112}
120113
// Dequantizes one 4-bit NF4 code to its unscaled float value.
//
// val: NF4 code in the low nibble; higher bits are masked off and ignored.
// Returns nf4_dequantization_lut[val & 0x0F]. The table holds the same values
// as the decision tree it replaces, which were generated by
// test_normal_map_tree in tests/test_functional.py.
__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }
166115
167116__device__ unsigned char dQuantizeNF4 (float x) {
168117
@@ -510,8 +459,8 @@ __global__ void
510459 case FP4:
511460#pragma unroll NUM_PER_TH
512461 for (int j = 0 ; j < NUM_PER_TH; j++) {
513- vals[j * 2 ] = dDequantizeFP4Tree (qvals[j] >> 4 , local_abs_max) ;
514- vals[j * 2 + 1 ] = dDequantizeFP4Tree (qvals[j] & 0x0F , local_abs_max) ;
462+ vals[j * 2 ] = dDequantizeFP4Tree (qvals[j] >> 4 ) * local_abs_max;
463+ vals[j * 2 + 1 ] = dDequantizeFP4Tree (qvals[j] & 0x0F ) * local_abs_max;
515464 }
516465 break ;
517466 case NF4:
@@ -2352,7 +2301,7 @@ __global__ void kgemm_4bit_inference(
23522301
23532302#pragma unroll 16
23542303 for (int i = 0 ; i < 16 ; i++)
2355- quant_map[i] = nf4_data [i];
2304+ quant_map[i] = nf4_dequantization_lut [i];
23562305 // __shared__ T quant_map[16*160];
23572306
23582307 T local_A[2 ];
0 commit comments