Skip to content

Commit c9382df

Browse files
committed
Add q5_k mmq
1 parent 6d83a8d commit c9382df

File tree

9 files changed

+45
-14
lines changed

9 files changed

+45
-14
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2975,6 +2975,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
29752975
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
29762976
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
29772977
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
2978+
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
29782979
}
29792980
#endif
29802981

@@ -3098,6 +3099,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
30983099
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
30993100
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
31003101
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3102+
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
31013103
}
31023104
#endif
31033105

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
567567

568568
const uint8_t hm = uint8_t(1 << (iqs / 16));
569569

570-
const vec2 loadd = vec2(data_a[a_offset + ib].d);
570+
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
571571

572572
const uint scidx0 = (is < 4) ? is : (is + 4);
573573
const uint scidx1 = (is < 4) ? is : (is - 4);

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ void main() {
1919
const uint ir = tid % 16;
2020
const uint is = 2 * il;
2121

22-
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
23-
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
22+
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
23+
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
2424

2525
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
2626
const uint qs_idx = 32*il + 2 * ir;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
1414

1515
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
1616
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
17-
vec2 d = vec2(data_a[ib0 + i].d);
18-
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
19-
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
17+
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].dm.x);
18+
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].dm.y);
2019

2120
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
2221
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
215215

216216
const uint8_t hm = uint8_t(1 << (iqs / 16));
217217

218-
const vec2 loadd = vec2(data_a[ib].d);
218+
const vec2 loadd = vec2(data_a[ib].dm);
219219

220220
const uint scidx0 = (is < 4) ? is : (is + 4);
221221
const uint scidx1 = (is < 4) ? is : (is - 4);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,8 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
346346
#endif // MMQ_SHMEM
347347
#endif
348348

349-
#if defined(DATA_A_Q4_K)
350-
// 4-byte loads for Q4_K blocks (144 bytes)
349+
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
350+
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
351351
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
352352
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
353353
}
@@ -361,10 +361,19 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
361361
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
362362

363363
// Repack 2x4 quants into one int
364+
#if defined(DATA_A_Q4_K)
364365
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
365366
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
366367

367368
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
369+
#else // defined(DATA_A_Q5_K)
370+
const uint qh_idx = iqs * QUANT_R_MMQ;
371+
const uint qh_shift = iqs_k / 8;
372+
373+
buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
374+
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
375+
#endif
376+
368377

369378
if (iqs == 0) {
370379
// Scale index
@@ -384,7 +393,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
384393
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
385394
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
386395

387-
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
396+
[[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
388397
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
389398
}
390399
}
@@ -393,7 +402,11 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
393402
int32_t q_sum = 0;
394403

395404
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
405+
#if defined(DATA_A_Q4_K)
396406
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
407+
#else // defined(DATA_A_Q5_K)
408+
const int32_t qs_a = cache_a[ib_a].qs[iqs];
409+
#endif
397410

398411
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
399412
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ struct block_a_cache {
2727
#elif defined(DATA_A_Q8_0)
2828
#define QUANT_R_MMQ 1
2929
// AMD likes 4, Intel likes 1 and Nvidia likes 2
30-
#define BK_STEP 4
30+
#define BK_STEP 1
3131
struct block_a_cache {
3232
int32_t qs[32/4];
3333
FLOAT_TYPE dm;
@@ -51,6 +51,14 @@ struct block_a_cache {
5151
uint32_t qs[4];
5252
FLOAT_TYPE_VEC2 dm;
5353
};
54+
#elif defined(DATA_A_Q5_K)
55+
#define QUANT_R_MMQ 1
56+
// AMD likes 4, Intel likes 1 and Nvidia likes 2
57+
#define BK_STEP 4
58+
struct block_a_cache {
59+
int32_t qs[8];
60+
FLOAT_TYPE_VEC2 dm;
61+
};
5462
#endif
5563

5664
struct block_b_cache

ggml/src/ggml-vulkan/vulkan-shaders/types.glsl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,20 +325,28 @@ struct block_q4_K_packed128
325325

326326
struct block_q5_K
327327
{
328-
f16vec2 d;
328+
f16vec2 dm;
329329
uint8_t scales[12];
330330
uint8_t qh[QUANT_K_Q5_K/8];
331331
uint8_t qs[QUANT_K_Q5_K/2];
332332
};
333333

334334
struct block_q5_K_packed16
335335
{
336-
f16vec2 d;
336+
f16vec2 dm;
337337
uint16_t scales[12/2];
338338
uint16_t qh[QUANT_K_Q5_K/8/2];
339339
uint16_t qs[QUANT_K_Q5_K/2/2];
340340
};
341341

342+
struct block_q5_K_packed32
343+
{
344+
f16vec2 dm;
345+
uint32_t scales[12/4];
346+
uint32_t qh[QUANT_K_Q5_K/8/4];
347+
uint32_t qs[QUANT_K_Q5_K/2/4];
348+
};
349+
342350
struct block_q5_K_packed128
343351
{
344352
uvec4 q5k[11];
@@ -349,6 +357,7 @@ struct block_q5_K_packed128
349357
#define QUANT_R 1
350358
#define A_TYPE block_q5_K
351359
#define A_TYPE_PACKED16 block_q5_K_packed16
360+
#define A_TYPE_PACKED32 block_q5_K_packed32
352361
#define DATA_A_QUANT_K
353362
#endif
354363

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
567567

568568
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
569569
// Integer dot mmq performs better with f32 accumulators
570-
if (!f16acc && !coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k" || tname == "q3_k" || tname == "q4_k")) {
570+
if (!f16acc && !coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k" || tname == "q3_k" || tname == "q4_k" | tname == "q5_k")) {
571571
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
572572
}
573573
#endif

0 commit comments

Comments
 (0)