Skip to content

Commit 65b66b2

Browse files
Improve kDequantizeBlockwise kernel performance for NF4/FP4
1 parent 39dd847 commit 65b66b2

File tree

1 file changed

+124
-78
lines changed

1 file changed

+124
-78
lines changed

csrc/kernels.cu

Lines changed: 124 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -51,29 +51,6 @@ __device__ float atomicMax(float* address, float val) {
5151
return __int_as_float(old);
5252
}
5353

54-
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) {
55-
float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
56-
if ((val & 0b0100) == 4) // 0
57-
if ((val & 0b0010) == 2) // 01
58-
if ((val & 0b0001) == 1) // 111
59-
return 0.25000000f * absmax * sign; // 1111
60-
else
61-
return 0.16666667f * absmax * sign; // 1110
62-
else if ((val & 0b0001) == 1) // 110
63-
return 0.50000000f * absmax * sign; // 1101
64-
else
65-
return 0.33333333f * absmax * sign; // 1100
66-
else if ((val & 0b0010) == 2) // 10
67-
if ((val & 0b0001) == 1) // 101
68-
return 1.00000000f * absmax * sign; // 1011
69-
else
70-
return 0.66666667f * absmax * sign; // 1010
71-
else if ((val & 0b0001) == 1) // 100
72-
return 5.208333333e-03f * absmax * sign; // 1001
73-
else
74-
return 0.00000000f * absmax * sign; // 1000
75-
}
76-
7754
__device__ unsigned char dQuantizeFP4(float x) {
7855
// FP4 with bias of 3
7956
// first bit is a sign
@@ -118,52 +95,6 @@ __device__ unsigned char dQuantizeFP4(float x) {
11895
return 0b0000 + sign;
11996
}
12097

121-
__device__ __forceinline__ float dDequantizeNF4(unsigned char val) {
122-
123-
// the values for this tree was generated by test_normal_map_tree
124-
// in the file tests/test_functional.py
125-
if ((val & 0b1000) == 8)
126-
if ((val & 0b0100) == 4) // 1
127-
if ((val & 0b0010) == 2) // 11
128-
if ((val & 0b0001) == 1) // 111
129-
return 1.0f;
130-
else
131-
return 0.7229568362236023f;
132-
else if ((val & 0b0001) == 1) // 110
133-
return 0.5626170039176941f;
134-
else
135-
return 0.44070982933044434f;
136-
else if ((val & 0b0010) == 2) // 10
137-
if ((val & 0b0001) == 1) // 101
138-
return 0.33791524171829224f;
139-
else
140-
return 0.24611230194568634f;
141-
else if ((val & 0b0001) == 1) // 100
142-
return 0.16093020141124725f;
143-
else
144-
return 0.07958029955625534f;
145-
146-
else if ((val & 0b0100) == 4) // 0
147-
if ((val & 0b0010) == 2) // 01
148-
if ((val & 0b0001) == 1) // 011
149-
return 0.0f;
150-
else
151-
return -0.09105003625154495f;
152-
else if ((val & 0b0001) == 1) // 010
153-
return -0.18477343022823334f;
154-
else
155-
return -0.28444138169288635f;
156-
else if ((val & 0b0010) == 2) // 00
157-
if ((val & 0b0001) == 1) // 001
158-
return -0.39491748809814453f;
159-
else
160-
return -0.5250730514526367f;
161-
else if ((val & 0b0001) == 1) // 000
162-
return -0.6961928009986877f;
163-
else
164-
return -1.0f;
165-
}
166-
16798
__device__ unsigned char dQuantizeNF4(float x) {
16899

169100
// the values for this tree was generated by test_normal_map_tree
@@ -468,6 +399,8 @@ template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
468399
__global__ void
469400
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) {
470401

402+
const int lane_id = threadIdx.x & 31;
403+
471404
const int n_load = (gridDim.x * TILE_SIZE);
472405
int valid_items_load = 0;
473406
int valid_items_store = 0;
@@ -483,8 +416,122 @@ __global__ void
483416
__shared__ typename LoadChar::TempStorage loadchar;
484417
__shared__ typename StoreT::TempStorage storet;
485418

419+
// Each thread in the warp holds one 4-bit LUT value for cooperative shuffling
420+
float my_lut_val;
421+
if constexpr (DATA_TYPE == NF4) {
422+
// NF4 lookup table
423+
switch (lane_id & 0xF) {
424+
case 0:
425+
my_lut_val = -1.0f;
426+
break;
427+
case 1:
428+
my_lut_val = -0.6961928009986877f;
429+
break;
430+
case 2:
431+
my_lut_val = -0.5250730514526367f;
432+
break;
433+
case 3:
434+
my_lut_val = -0.39491748809814453f;
435+
break;
436+
case 4:
437+
my_lut_val = -0.28444138169288635f;
438+
break;
439+
case 5:
440+
my_lut_val = -0.18477343022823334f;
441+
break;
442+
case 6:
443+
my_lut_val = -0.09105003625154495f;
444+
break;
445+
case 7:
446+
my_lut_val = 0.0f;
447+
break;
448+
case 8:
449+
my_lut_val = 0.07958029955625534f;
450+
break;
451+
case 9:
452+
my_lut_val = 0.16093020141124725f;
453+
break;
454+
case 10:
455+
my_lut_val = 0.24611230194568634f;
456+
break;
457+
case 11:
458+
my_lut_val = 0.33791524171829224f;
459+
break;
460+
case 12:
461+
my_lut_val = 0.44070982933044434f;
462+
break;
463+
case 13:
464+
my_lut_val = 0.5626170039176941f;
465+
break;
466+
case 14:
467+
my_lut_val = 0.7229568362236023f;
468+
break;
469+
case 15:
470+
my_lut_val = 1.0f;
471+
break;
472+
default:
473+
my_lut_val = 0.0f;
474+
break;
475+
}
476+
} else if constexpr (DATA_TYPE == FP4) {
477+
// FP4 lookup table
478+
switch (lane_id & 0xF) {
479+
case 0:
480+
my_lut_val = 0.00000000f;
481+
break;
482+
case 1:
483+
my_lut_val = 0.00520833f;
484+
break;
485+
case 2:
486+
my_lut_val = 0.66666667f;
487+
break;
488+
case 3:
489+
my_lut_val = 1.00000000f;
490+
break;
491+
case 4:
492+
my_lut_val = 0.33333333f;
493+
break;
494+
case 5:
495+
my_lut_val = 0.50000000f;
496+
break;
497+
case 6:
498+
my_lut_val = 0.16666667f;
499+
break;
500+
case 7:
501+
my_lut_val = 0.25000000f;
502+
break;
503+
case 8:
504+
my_lut_val = 0.00000000f;
505+
break;
506+
case 9:
507+
my_lut_val = -0.00520833f;
508+
break;
509+
case 10:
510+
my_lut_val = -0.66666667f;
511+
break;
512+
case 11:
513+
my_lut_val = -1.00000000f;
514+
break;
515+
case 12:
516+
my_lut_val = -0.33333333f;
517+
break;
518+
case 13:
519+
my_lut_val = -0.50000000f;
520+
break;
521+
case 14:
522+
my_lut_val = -0.16666667f;
523+
break;
524+
case 15:
525+
my_lut_val = -0.25000000f;
526+
break;
527+
default:
528+
my_lut_val = 0.00000000f;
529+
break;
530+
}
531+
}
532+
486533
for (int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) {
487-
if (DATA_TYPE > 0) {
534+
if constexpr (DATA_TYPE > 0) {
488535
valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i);
489536
valid_items_store = min(TILE_SIZE * 2, n - i * 2);
490537
} else {
@@ -508,17 +555,16 @@ __global__ void
508555
vals[j] = __ldg(&code[qvals[j]]) * local_abs_max;
509556
break;
510557
case FP4:
511-
#pragma unroll NUM_PER_TH
512-
for (int j = 0; j < NUM_PER_TH; j++) {
513-
vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
514-
vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
515-
}
516-
break;
517558
case NF4:
559+
// Each warp will cooperatively shuffle the LUT values
560+
// so that each thread has access to all 16 possible values.
561+
// This avoids the need for shared memory and branches.
518562
#pragma unroll NUM_PER_TH
519563
for (int j = 0; j < NUM_PER_TH; j++) {
520-
vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max;
521-
vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max;
564+
const unsigned char high_val = qvals[j] >> 4;
565+
const unsigned char low_val = qvals[j] & 0x0F;
566+
vals[j * 2] = __shfl_sync(0xFFFFFFFF, my_lut_val, high_val) * local_abs_max;
567+
vals[j * 2 + 1] = __shfl_sync(0xFFFFFFFF, my_lut_val, low_val) * local_abs_max;
522568
}
523569
break;
524570
}

0 commit comments

Comments
 (0)