Skip to content

Commit 6e7ec12

Browse files
Nexesenex and ikawrakow committed
Q6_0 MMQ Kernel attempt
MMQ kernel for Q6_0, authored by Ikawrakow. Also adds Q6_0 to the MMQ template generator (authored by Ikawrakow). Co-Authored-By: Kawrakow <[email protected]>
1 parent 3788953 commit 6e7ec12

File tree

5 files changed

+87
-2
lines changed

5 files changed

+87
-2
lines changed

ggml/src/ggml-cuda/dmmv.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
637637
src1_dfloat = src1_dfloat_a.alloc(ne00);
638638
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
639639
GGML_ASSERT(to_fp16_cuda != nullptr);
640-
to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream);
640+
to_fp16_cuda(src1_ddf_i, src1_dfloat, 1, ne00, stream);
641641
}
642642
#else
643643
const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion

ggml/src/ggml-cuda/mmq.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ void ggml_cuda_op_mul_mat_q(
4343
case GGML_TYPE_Q5_1:
4444
mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
4545
break;
46+
case GGML_TYPE_Q6_0:
47+
mul_mat_q_case<GGML_TYPE_Q6_0>(ctx, args, stream);
48+
break;
4649
case GGML_TYPE_Q8_0:
4750
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
4851
break;
@@ -109,6 +112,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
109112
//case GGML_TYPE_Q4_1:
110113
case GGML_TYPE_Q5_0:
111114
case GGML_TYPE_Q5_1:
115+
case GGML_TYPE_Q6_0:
112116
case GGML_TYPE_Q8_0:
113117
case GGML_TYPE_Q2_K:
114118
case GGML_TYPE_Q3_K:

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
5454
return MMQ_Q8_1_DS_LAYOUT_D4;
5555
case GGML_TYPE_Q5_1:
5656
return MMQ_Q8_1_DS_LAYOUT_DS4;
57+
case GGML_TYPE_Q6_0:
58+
return MMQ_Q8_1_DS_LAYOUT_D4;
5759
case GGML_TYPE_Q8_0:
5860
return MMQ_Q8_1_DS_LAYOUT_D4;
5961
case GGML_TYPE_Q2_K:
@@ -156,6 +158,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
156158
//type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
157159
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
158160
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
161+
type == GGML_TYPE_Q6_0 ? MMQ_DP4A_TXS_Q8_0 :
159162
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
160163
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
161164
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
@@ -190,6 +193,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
190193
//type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
191194
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
192195
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
196+
type == GGML_TYPE_Q6_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
193197
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
194198
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
195199
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
@@ -557,6 +561,69 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
557561
}
558562
}
559563

564+
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_0(
565+
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
566+
567+
#ifdef INT8_MMA_AVAILABLE
568+
int * x_qs = (int *) x_tile;
569+
float * x_df = (float *) (x_qs + WARP_SIZE*2);
570+
#else
571+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_0, mmq_y);
572+
int * x_qs = (int *) x_tile;
573+
float * x_df = (float *) (x_qs + txs.qs);
574+
#endif // INT8_MMA_AVAILABLE
575+
576+
const int kbx = threadIdx.x / QI6_0;
577+
const int kqsx = threadIdx.x % QI6_0;
578+
579+
#pragma unroll
580+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
581+
int i = i0 + threadIdx.y;
582+
583+
if (need_check) {
584+
i = min(i, i_max);
585+
}
586+
587+
const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbx;
588+
589+
const int ql = get_int_b2(bxi->qs, kqsx);
590+
const int qh = get_int_b2(bxi->qh, kqsx%2) >> 4*(kqsx/2);
591+
592+
int qs0 = ((ql >> 0) & 0x0F0F0F0F) | ((qh << 4) & 0x30303030);
593+
int qs1 = ((ql >> 4) & 0x0F0F0F0F) | ((qh << 2) & 0x30303030);
594+
qs0 = __vsubss4(qs0, 0x20202020); // subtract 32
595+
qs1 = __vsubss4(qs1, 0x20202020); // subtract 32
596+
597+
#ifdef INT8_MMA_AVAILABLE
598+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + 0] = qs0;
599+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1;
600+
#else
601+
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI6_0) + kqsx + 0] = qs0;
602+
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1;
603+
#endif // INT8_MMA_AVAILABLE
604+
}
605+
606+
const int blocks_per_tile_x_row = WARP_SIZE / QI6_0;
607+
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
608+
609+
#pragma unroll
610+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_0) {
611+
int i = i0 + threadIdx.y * QI6_0 + threadIdx.x / blocks_per_tile_x_row;
612+
613+
if (need_check) {
614+
i = min(i, i_max);
615+
}
616+
617+
const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbxd;
618+
619+
#ifdef INT8_MMA_AVAILABLE
620+
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
621+
#else
622+
x_df[i*(WARP_SIZE/QI6_0) + i/QI6_0 + kbxd] = bxi->d;
623+
#endif // INT8_MMA_AVAILABLE
624+
}
625+
}
626+
560627
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
561628
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
562629

@@ -2380,6 +2447,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
23802447
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
23812448
};
23822449

2450+
// MMQ dispatch traits for Q6_0. load_tiles_q6_0 dequantizes the tile to signed
// int8 in q8_0 layout, so both dot products reuse the q8_0 x q8_1 kernels
// (MMA path with the D4 scale layout, matching mmq_get_q8_1_ds_layout above).
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_0> {
    static constexpr int               vdr          = VDR_Q6_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t  load_tiles   = load_tiles_q6_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t     vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t     vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
2457+
23832458
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
23842459
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
23852460
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
@@ -2911,6 +2986,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
29112986
//extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
29122987
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
29132988
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
2989+
extern DECL_MMQ_CASE(GGML_TYPE_Q6_0);
29142990
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
29152991
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
29162992
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
2525
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
2626
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
27-
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
27+
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_Q6_0"
2828
]
2929

3030
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0.cu (new file — filename inferred from the generator's naming scheme; verify)

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmq.cuh"
4+
5+
DECL_MMQ_CASE(GGML_TYPE_Q6_0);

0 commit comments

Comments
 (0)