Skip to content

Commit 84cb48c

Browse files
committed
Add q6_k mmq
1 parent c9382df commit 84cb48c

File tree

5 files changed

+69
-4
lines changed

5 files changed

+69
-4
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2976,6 +2976,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
29762976
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
29772977
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
29782978
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
2979+
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
29792980
}
29802981
#endif
29812982

@@ -3100,6 +3101,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
31003101
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
31013102
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
31023103
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3104+
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
31033105
}
31043106
#endif
31053107

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,65 @@ void block_b_to_registers(const uint ib) {
440440
}
441441
#endif
442442

443+
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
//
// A Q6_K superblock holds 256 6-bit quants: ql[] carries the low 4 bits,
// qh[] the high 2 bits, with one int8 scale per 16 quants and a float16
// superblock scale d.  All accesses below go through data_a_packed16, i.e.
// uint16 (2-byte) loads.
#ifdef MMQ_SHMEM
// Dequantize 4 consecutive quants of superblock `ib/8` into shared memory.
// buf_ib: destination slot in buf_a; ib: 32-quant sub-block index (8 per
// superblock); iqs: 0..7, selecting one 4-quant unit inside the sub-block.
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint ib_k = ib / 8;               // superblock index
    const uint iqs_k = (ib % 8) * 8 + iqs;  // 4-quant unit index within the superblock (0..63)

    // Low 4 bits: two consecutive uint16 of ql, taken from the low or high
    // nibble depending on which quarter of the 128-quant half we are in.
    const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
    const uint ql_shift = ((iqs_k % 32) / 16) * 4;

    // High 2 bits: each qh byte feeds four quants, selected by a 0/2/4/6 shift.
    const uint qh_idx = (iqs_k / 32) * 8 + iqs;
    const uint qh_shift = ((iqs_k % 32) / 8) * 2;

    // Merge low nibble and high 2 bits, then recenter from [0,63] to [-32,31].
    const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
    const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
    // Pack the 4 dequantized int8 values into one int32 for dotPacked4x8EXT.
    buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));

    if (iqs == 0) {
        // One scale per 16 quants -> two scales per 32-quant sub-block.
        // Pre-multiply both by the superblock scale d so the dot product
        // only needs one multiply per 16-quant group.
        const uint is = iqs_k / 4;
        const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);

        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
    }
}

// Copy one dequantized sub-block (8 packed int32 = 32 quants, plus its two
// pre-scaled factors) from shared memory into the per-thread register cache.
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;

    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
    }
}

// Integer dot product of one cached Q6_K sub-block with the cached q8_1
// block: two 16-quant halves, each accumulated in int32 and weighted by its
// own d*scale factor, then scaled by the q8_1 block scale (cache_b.ds.x).
ACC_TYPE mmq_dot_product(const uint ib_a) {
    float result = 0.0;
    int32_t q_sum = 0;

    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
        const int32_t qs_a = cache_a[ib_a].qs[iqs];

        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
    }
    result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
    q_sum = 0;

    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
        const int32_t qs_a = cache_a[ib_a].qs[iqs];

        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
    }
    result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);

    return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
501+
443502
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
444503
FLOAT_TYPE get_d(uint ib) {
445504
return FLOAT_TYPE(data_a[ib].d);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,16 @@ struct block_a_cache {
5353
};
5454
#elif defined(DATA_A_Q5_K)
5555
#define QUANT_R_MMQ 1
56-
// AMD likes 4, Intel likes 1 and Nvidia likes 2
57-
#define BK_STEP 4
5856
struct block_a_cache {
5957
int32_t qs[8];
6058
FLOAT_TYPE_VEC2 dm;
6159
};
60+
#elif defined(DATA_A_Q6_K)
61+
#define QUANT_R_MMQ 1
62+
struct block_a_cache {
63+
int32_t qs[8];
64+
FLOAT_TYPE_VEC2 d_scales;
65+
};
6266
#endif
6367

6468
struct block_b_cache

ggml/src/ggml-vulkan/vulkan-shaders/types.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
// Q6_K superblock (210 bytes) reinterpreted for 2-byte (uint16) loads.
struct block_q6_K_packed16
{
    uint16_t ql[QUANT_K_Q6_K/2/2];       // low 4 bits of each 6-bit quant (128 bytes)
    uint16_t qh[QUANT_K_Q6_K/4/2];       // high 2 bits of each 6-bit quant (64 bytes)
    // int8 scales (one per 16 quants) packed in pairs so they can be
    // fetched with a single 16-bit load and split with unpack8().
    int16_t scales[QUANT_K_Q6_K/16/2];
    float16_t d;                         // superblock scale
};

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
567567

568568
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
569569
// Integer dot mmq performs better with f32 accumulators
570-
if (!f16acc && !coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k" || tname == "q3_k" || tname == "q4_k" | tname == "q5_k")) {
570+
if (!f16acc && !coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || is_k_quant(tname))) {
571571
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
572572
}
573573
#endif

0 commit comments

Comments
 (0)