@@ -51,29 +51,6 @@ __device__ float atomicMax(float* address, float val) {
5151 return __int_as_float (old);
5252}
5353
// Decodes one 4-bit FP4 code into a float scaled by `absmax`.
//
// Only the low four bits of `val` are examined:
//   bit 3     - sign (set means negative)
//   bits 2..0 - selector for one of eight fixed magnitudes
//
// The result is evaluated as magnitude * absmax * sign, in that order.
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) {
    const float sign = (val & 0x8) ? -1.0f : 1.0f;

    float magnitude;
    switch (val & 0x7) {
    case 0x7: // 0b111
        magnitude = 0.25000000f;
        break;
    case 0x6: // 0b110
        magnitude = 0.16666667f;
        break;
    case 0x5: // 0b101
        magnitude = 0.50000000f;
        break;
    case 0x4: // 0b100
        magnitude = 0.33333333f;
        break;
    case 0x3: // 0b011
        magnitude = 1.00000000f;
        break;
    case 0x2: // 0b010
        magnitude = 0.66666667f;
        break;
    case 0x1: // 0b001
        magnitude = 5.208333333e-03f;
        break;
    default: // 0b000
        magnitude = 0.00000000f;
        break;
    }

    return magnitude * absmax * sign;
}

76-
7754__device__ unsigned char dQuantizeFP4 (float x) {
7855 // FP4 with bias of 3
7956 // first bit is a sign
@@ -118,52 +95,6 @@ __device__ unsigned char dQuantizeFP4(float x) {
11895 return 0b0000 + sign;
11996}
12097
// Maps a 4-bit NF4 code to its unscaled float value.
//
// Only the low four bits of `val` take part in the lookup; the sixteen
// codes run from -1.0f (0b0000) up to 1.0f (0b1111), with 0b0111 mapping
// to exactly 0.0f.
//
// The constants were generated by test_normal_map_tree in
// tests/test_functional.py.
__device__ __forceinline__ float dDequantizeNF4(unsigned char val) {
    switch (val & 0x0F) {
    case 0xF: // 0b1111
        return 1.0f;
    case 0xE: // 0b1110
        return 0.7229568362236023f;
    case 0xD: // 0b1101
        return 0.5626170039176941f;
    case 0xC: // 0b1100
        return 0.44070982933044434f;
    case 0xB: // 0b1011
        return 0.33791524171829224f;
    case 0xA: // 0b1010
        return 0.24611230194568634f;
    case 0x9: // 0b1001
        return 0.16093020141124725f;
    case 0x8: // 0b1000
        return 0.07958029955625534f;
    case 0x7: // 0b0111
        return 0.0f;
    case 0x6: // 0b0110
        return -0.09105003625154495f;
    case 0x5: // 0b0101
        return -0.18477343022823334f;
    case 0x4: // 0b0100
        return -0.28444138169288635f;
    case 0x3: // 0b0011
        return -0.39491748809814453f;
    case 0x2: // 0b0010
        return -0.5250730514526367f;
    case 0x1: // 0b0001
        return -0.6961928009986877f;
    default: // 0b0000
        return -1.0f;
    }
}

166-
16798__device__ unsigned char dQuantizeNF4 (float x) {
16899
169100 // the values for this tree was generated by test_normal_map_tree
@@ -468,6 +399,8 @@ template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
468399__global__ void
469400 kDequantizeBlockwise (float * code, unsigned char * A, float * absmax, T* out, const int blocksize, const int n) {
470401
402+ const int lane_id = threadIdx .x & 31 ;
403+
471404 const int n_load = (gridDim .x * TILE_SIZE);
472405 int valid_items_load = 0 ;
473406 int valid_items_store = 0 ;
@@ -483,8 +416,122 @@ __global__ void
483416 __shared__ typename LoadChar::TempStorage loadchar;
484417 __shared__ typename StoreT::TempStorage storet;
485418
419+ // Each thread in the warp holds one 4-bit LUT value for cooperative shuffling
420+ float my_lut_val;
421+ if constexpr (DATA_TYPE == NF4) {
422+ // NF4 lookup table
423+ switch (lane_id & 0xF ) {
424+ case 0 :
425+ my_lut_val = -1 .0f ;
426+ break ;
427+ case 1 :
428+ my_lut_val = -0 .6961928009986877f ;
429+ break ;
430+ case 2 :
431+ my_lut_val = -0 .5250730514526367f ;
432+ break ;
433+ case 3 :
434+ my_lut_val = -0 .39491748809814453f ;
435+ break ;
436+ case 4 :
437+ my_lut_val = -0 .28444138169288635f ;
438+ break ;
439+ case 5 :
440+ my_lut_val = -0 .18477343022823334f ;
441+ break ;
442+ case 6 :
443+ my_lut_val = -0 .09105003625154495f ;
444+ break ;
445+ case 7 :
446+ my_lut_val = 0 .0f ;
447+ break ;
448+ case 8 :
449+ my_lut_val = 0 .07958029955625534f ;
450+ break ;
451+ case 9 :
452+ my_lut_val = 0 .16093020141124725f ;
453+ break ;
454+ case 10 :
455+ my_lut_val = 0 .24611230194568634f ;
456+ break ;
457+ case 11 :
458+ my_lut_val = 0 .33791524171829224f ;
459+ break ;
460+ case 12 :
461+ my_lut_val = 0 .44070982933044434f ;
462+ break ;
463+ case 13 :
464+ my_lut_val = 0 .5626170039176941f ;
465+ break ;
466+ case 14 :
467+ my_lut_val = 0 .7229568362236023f ;
468+ break ;
469+ case 15 :
470+ my_lut_val = 1 .0f ;
471+ break ;
472+ default :
473+ my_lut_val = 0 .0f ;
474+ break ;
475+ }
476+ } else if constexpr (DATA_TYPE == FP4) {
477+ // FP4 lookup table
478+ switch (lane_id & 0xF ) {
479+ case 0 :
480+ my_lut_val = 0 .00000000f ;
481+ break ;
482+ case 1 :
483+ my_lut_val = 0 .00520833f ;
484+ break ;
485+ case 2 :
486+ my_lut_val = 0 .66666667f ;
487+ break ;
488+ case 3 :
489+ my_lut_val = 1 .00000000f ;
490+ break ;
491+ case 4 :
492+ my_lut_val = 0 .33333333f ;
493+ break ;
494+ case 5 :
495+ my_lut_val = 0 .50000000f ;
496+ break ;
497+ case 6 :
498+ my_lut_val = 0 .16666667f ;
499+ break ;
500+ case 7 :
501+ my_lut_val = 0 .25000000f ;
502+ break ;
503+ case 8 :
504+ my_lut_val = 0 .00000000f ;
505+ break ;
506+ case 9 :
507+ my_lut_val = -0 .00520833f ;
508+ break ;
509+ case 10 :
510+ my_lut_val = -0 .66666667f ;
511+ break ;
512+ case 11 :
513+ my_lut_val = -1 .00000000f ;
514+ break ;
515+ case 12 :
516+ my_lut_val = -0 .33333333f ;
517+ break ;
518+ case 13 :
519+ my_lut_val = -0 .50000000f ;
520+ break ;
521+ case 14 :
522+ my_lut_val = -0 .16666667f ;
523+ break ;
524+ case 15 :
525+ my_lut_val = -0 .25000000f ;
526+ break ;
527+ default :
528+ my_lut_val = 0 .00000000f ;
529+ break ;
530+ }
531+ }
532+
486533 for (int i = base_idx; i < n_load; i += gridDim .x * TILE_SIZE) {
487- if (DATA_TYPE > 0 ) {
534+ if constexpr (DATA_TYPE > 0 ) {
488535 valid_items_load = min (TILE_SIZE, (n + 1 ) / 2 - i);
489536 valid_items_store = min (TILE_SIZE * 2 , n - i * 2 );
490537 } else {
@@ -508,17 +555,16 @@ __global__ void
508555 vals[j] = __ldg (&code[qvals[j]]) * local_abs_max;
509556 break ;
510557 case FP4:
511- #pragma unroll NUM_PER_TH
512- for (int j = 0 ; j < NUM_PER_TH; j++) {
513- vals[j * 2 ] = dDequantizeFP4Tree (qvals[j] >> 4 , local_abs_max);
514- vals[j * 2 + 1 ] = dDequantizeFP4Tree (qvals[j] & 0x0F , local_abs_max);
515- }
516- break ;
517558 case NF4:
559+ // Each warp will cooperatively shuffle the LUT values
560+ // so that each thread has access to all 16 possible values.
561+ // This avoids the need for shared memory and branches.
518562#pragma unroll NUM_PER_TH
519563 for (int j = 0 ; j < NUM_PER_TH; j++) {
520- vals[j * 2 ] = dDequantizeNF4 (qvals[j] >> 4 ) * local_abs_max;
521- vals[j * 2 + 1 ] = dDequantizeNF4 (qvals[j] & 0x0F ) * local_abs_max;
564+ const unsigned char high_val = qvals[j] >> 4 ;
565+ const unsigned char low_val = qvals[j] & 0x0F ;
566+ vals[j * 2 ] = __shfl_sync (0xFFFFFFFF , my_lut_val, high_val) * local_abs_max;
567+ vals[j * 2 + 1 ] = __shfl_sync (0xFFFFFFFF , my_lut_val, low_val) * local_abs_max;
522568 }
523569 break ;
524570 }
0 commit comments