@@ -54,6 +54,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
5454 return MMQ_Q8_1_DS_LAYOUT_D4;
5555 case GGML_TYPE_Q5_1:
5656 return MMQ_Q8_1_DS_LAYOUT_DS4;
57+ case GGML_TYPE_Q6_0:
58+ return MMQ_Q8_1_DS_LAYOUT_D4;
5759 case GGML_TYPE_Q8_0:
5860 return MMQ_Q8_1_DS_LAYOUT_D4;
5961 case GGML_TYPE_Q2_K:
@@ -156,6 +158,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
156158 // type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
157159 type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
158160 type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
161+ type == GGML_TYPE_Q6_0 ? MMQ_DP4A_TXS_Q8_0 :
159162 type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
160163 type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
161164 type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
@@ -190,6 +193,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
190193 // type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
191194 type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
192195 type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
196+ type == GGML_TYPE_Q6_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
193197 type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
194198 type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
195199 type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
@@ -557,6 +561,69 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
557561 }
558562}
559563
// Load one tile of Q6_0 blocks into shared memory, expanded to signed 8-bit
// quants in the Q8_0 tile layout (MMQ_MMA_TILE_X_K_Q8_0 / MMQ_DP4A_TXS_Q8_0),
// so the existing Q8_0 vec-dot kernels can be reused for Q6_0.
// Each 6-bit quant is split across bxi->qs (low 4 bits) and bxi->qh (high
// 2 bits, masked into bit positions 4..5); __vsubss4 then subtracts 32 from
// every byte to recenter the unsigned 6-bit value into a signed int8.
// NOTE(review): assumes one warp per tile row (threadIdx.x in [0, WARP_SIZE)),
// matching the other load_tiles_* helpers in this file.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + WARP_SIZE*2);  // scales follow the 2*WARP_SIZE unpacked quants
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_0, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    const int kbx  = threadIdx.x / QI6_0; // Q6_0 block index within the tile row
    const int kqsx = threadIdx.x % QI6_0; // 32-bit (4-quant) index within the block

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max); // clamp: out-of-range rows re-load the last valid row
        }

        const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbx;

        const int ql = get_int_b2(bxi->qs, kqsx);
        // qh packs 2 high bits per quant; pick the int (kqsx%2) and the nibble
        // within it (kqsx/2) that correspond to this group of 4 quants
        const int qh = get_int_b2(bxi->qh, kqsx%2) >> 4*(kqsx/2);

        // low nibble of ql + high 2 bits shifted into bits 4..5 -> 6-bit value
        int qs0 = ((ql >> 0) & 0x0F0F0F0F) | ((qh << 4) & 0x30303030);
        int qs1 = ((ql >> 4) & 0x0F0F0F0F) | ((qh << 2) & 0x30303030);
        qs0 = __vsubss4(qs0, 0x20202020); // subtract 32
        qs1 = __vsubss4(qs1, 0x20202020); // subtract 32

#ifdef INT8_MMA_AVAILABLE
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + 0]     = qs0;
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1;
#else
        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI6_0) + kqsx + 0]     = qs0;
        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1;
#endif // INT8_MMA_AVAILABLE
    }

    // second pass: load the per-block scales (d), one float per Q6_0 block
    const int blocks_per_tile_x_row = WARP_SIZE / QI6_0;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_0) {
        int i = i0 + threadIdx.y * QI6_0 + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbxd;

#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
#else
        // +i/QI6_0 skews rows to avoid shared-memory bank conflicts, matching
        // the other dp4a scale layouts in this file
        x_df[i*(WARP_SIZE/QI6_0) + i/QI6_0 + kbxd] = bxi->d;
#endif // INT8_MMA_AVAILABLE
    }
}
626+
560627template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0 (
561628 const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
562629
@@ -2380,6 +2447,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
23802447 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
23812448};
23822449
// Q6_0: load_tiles_q6_0 expands the quants into the Q8_0 tile layout, so the
// Q8_0 vec-dot kernels are reused, with the D4 (scale-only) q8_1 ds layout.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_0> {
    static constexpr int              vdr          = VDR_Q6_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
2457+
23832458template <int mmq_x, int mmq_y, int nwarps, bool need_check>
23842459struct mmq_type_traits <mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
23852460 static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
@@ -2911,6 +2986,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
29112986// extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
29122987extern DECL_MMQ_CASE (GGML_TYPE_Q5_0);
29132988extern DECL_MMQ_CASE (GGML_TYPE_Q5_1);
2989+ extern DECL_MMQ_CASE (GGML_TYPE_Q6_0);
29142990extern DECL_MMQ_CASE (GGML_TYPE_Q8_0);
29152991extern DECL_MMQ_CASE (GGML_TYPE_Q2_K);
29162992extern DECL_MMQ_CASE (GGML_TYPE_Q3_K);
0 commit comments