@@ -136,6 +136,49 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }
 
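+// Look up four 4-bit iq4_nl indices (one per byte of q4, high nibbles already
+// masked off by the caller) in the 16-entry kvalues_iq4nl table and repack the
+// four int8 results into a single int, ready for dp4a.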
+static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4) {
+    const uint8_t * q0_8 = (const uint8_t *) &q4;
+    const char4 val0_8 = make_char4(kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
+    return *((const int *) &val0_8);
+}
+
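+// K·Q dot product for a K cache stored as iq4_nl. Q has been quantized to
+// int8 (quants in Q_q8, scales in Q_ds_v), so the work reduces to integer
+// dp4a products scaled by the K and Q block scales.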
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_iq4_nl * K_iq4_nl = (const block_iq4_nl *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
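+        // k_KQ counts 32-bit words of Q data; QI8_1 of them span one 32-value
+        // block. ib selects the iq4_nl block, iqs4 the packed-nibble word
+        // within it, and shift the low (0) vs. high (4) nibbles of that word.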
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_NL;
+        const int shift = k_KQ & (QI8_1/2);
+
+        const int v = get_one_int_from_table_16((get_int_b2(K_iq4_nl[ib].qs, iqs4) >> shift) & 0x0F0F0F0F);
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
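+        // 4-way int8 dot product of the K codebook values and the Q quants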
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
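+        // scale by the K block scale d and the Q quantization scale Q_ds.x,
+        // in half precision when FP16 arithmetic is available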
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2 * Q_ds = (const half2 *) Q_ds_v;
+            sum += (T) (((half) sumi)*K_iq4_nl[ib].d*Q_ds[k_KQ_0/WARP_SIZE].x);
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+            sum += (T) (((float) sumi)*__half2float(K_iq4_nl[ib].d)*Q_ds[k_KQ_0/WARP_SIZE].x);
+        }
+    }
+
+    return sum;
+}
+
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -377,6 +420,25 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
     return ((float) d)*((float) q);
 }
 
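+// Dequantize a single element i of an iq4_nl-quantized row: locate the block,
+// the byte within it, and the low/high nibble, then map the 4-bit index
+// through the kvalues_iq4nl codebook and apply the block scale.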
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_iq4_nl(const void * __restrict__ vx, const int64_t i) {
+    const block_iq4_nl * x = (const block_iq4_nl *) vx;
+
+    const int64_t ib    =  i          /  QK4_NL;
+    const int     iqs   =  i          % (QK4_NL/2);
+    const int     shift = (i          %  QK4_NL) / (QK4_NL/2);
+
+    const T      d = x[ib].d;
+    const int8_t q = kvalues_iq4nl[(x[ib].qs[iqs] >> (4*shift)) & 0x0F];
+
+#ifdef FP16_AVAILABLE
+    if constexpr (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+    return ((float) d)*((float) q);
+}
+
 template <typename T>
 static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
     const block_q4_1 * x = (const block_q4_1 *) vx;
@@ -476,44 +538,48 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
 
 template <int D>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
-        nullptr;
+    return type_K == GGML_TYPE_Q4_0   ? vec_dot_fattn_vec_KQ_q4_0<half, D>   :
+        type_K == GGML_TYPE_Q4_1   ? vec_dot_fattn_vec_KQ_q4_1<half, D>   :
+        type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> :
+        type_K == GGML_TYPE_Q5_0   ? vec_dot_fattn_vec_KQ_q5_0<half, D>   :
+        type_K == GGML_TYPE_Q5_1   ? vec_dot_fattn_vec_KQ_q5_1<half, D>   :
+        type_K == GGML_TYPE_Q8_0   ? vec_dot_fattn_vec_KQ_q8_0<half, D>   :
+        type_K == GGML_TYPE_F16    ? vec_dot_fattn_vec_KQ_f16<half, D>    :
+        nullptr;
 }
 
 template <int D>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
-        nullptr;
+    return type_K == GGML_TYPE_Q4_0   ? vec_dot_fattn_vec_KQ_q4_0<float, D>   :
+        type_K == GGML_TYPE_Q4_1   ? vec_dot_fattn_vec_KQ_q4_1<float, D>   :
+        type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> :
+        type_K == GGML_TYPE_Q5_0   ? vec_dot_fattn_vec_KQ_q5_0<float, D>   :
+        type_K == GGML_TYPE_Q5_1   ? vec_dot_fattn_vec_KQ_q5_1<float, D>   :
+        type_K == GGML_TYPE_Q8_0   ? vec_dot_fattn_vec_KQ_q8_0<float, D>   :
+        type_K == GGML_TYPE_F16    ? vec_dot_fattn_vec_KQ_f16<float, D>    :
+        nullptr;
 }
 
 constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
-    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
-        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
-        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
-        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
-        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
-        type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
-        nullptr;
+    return type_V == GGML_TYPE_Q4_0   ? dequantize_1_q4_0<half>   :
+        type_V == GGML_TYPE_Q4_1   ? dequantize_1_q4_1<half>   :
+        type_V == GGML_TYPE_Q5_0   ? dequantize_1_q5_0<half>   :
+        type_V == GGML_TYPE_Q5_1   ? dequantize_1_q5_1<half>   :
+        type_V == GGML_TYPE_Q8_0   ? dequantize_1_q8_0<half>   :
+        type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<half> :
+        type_V == GGML_TYPE_F16    ? dequantize_1_f16<half>    :
+        nullptr;
 }
 
 constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
-    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
-        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
-        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
-        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
-        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
-        type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
-        nullptr;
+    return type_V == GGML_TYPE_Q4_0   ? dequantize_1_q4_0<float>   :
+        type_V == GGML_TYPE_Q4_1   ? dequantize_1_q4_1<float>   :
+        type_V == GGML_TYPE_Q5_0   ? dequantize_1_q5_0<float>   :
+        type_V == GGML_TYPE_Q5_1   ? dequantize_1_q5_1<float>   :
+        type_V == GGML_TYPE_Q8_0   ? dequantize_1_q8_0<float>   :
+        type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<float> :
+        type_V == GGML_TYPE_F16    ? dequantize_1_f16<float>    :
+        nullptr;
 }
 
 template <int D, int parallel_blocks> // D == head size
@@ -569,10 +635,12 @@ static void on_no_fattn_vec_case(const int D) {
     } else if (D == 128) {
         fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
         fprintf(stderr, "Supported combinations:\n");
-        fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
-        fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
-        fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
+        fprintf(stderr, "  - K == q4_0,   V == q4_0,    4.50 BPV\n");
+        fprintf(stderr, "  - K == iq4_nl, V == iq4_nl,  4.50 BPV\n");
+        fprintf(stderr, "  - K == q8_0,   V == iq4_nl,  6.50 BPV\n");
+        fprintf(stderr, "  - K == q8_0,   V == q8_0,    8.50 BPV\n");
+        fprintf(stderr, "  - K == f16,    V == f16,    16.00 BPV\n");
+        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
         fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");