Skip to content

Commit 5148d4a

Browse files
committed
Add q4_k mmq
1 parent 1309d7d commit 5148d4a

File tree

9 files changed

+79
-14
lines changed

9 files changed

+79
-14
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2973,6 +2973,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
29732973
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
29742974

29752975
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
2976+
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
29762977
}
29772978
#endif
29782979

@@ -3094,6 +3095,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
30943095
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
30953096

30963097
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3098+
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
30973099
}
30983100
#endif
30993101

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
529529
const uint is = 2 * n + b; // 0..7
530530
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
531531

532-
const vec2 loadd = vec2(data_a[a_offset + ib].d);
532+
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
533533

534534
const uint scidx0 = (is < 4) ? is : (is + 4);
535535
const uint scidx1 = (is < 4) ? is : (is - 4);

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ void main() {
2020
const uint is = 2 * il;
2121
const uint n = 4;
2222

23-
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
24-
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
23+
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
24+
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
2525

2626
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
2727
const uint qs_idx = 32*il + n * ir;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
1414

1515
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
1616
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
17-
vec2 d = vec2(data_a[ib0 + i].d);
18-
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
19-
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
17+
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].dm.x);
18+
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].dm.y);
2019

2120
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
2221
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
179179
const uint is = 2 * n + b; // 0..7
180180
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
181181

182-
const vec2 loadd = vec2(data_a[ib].d);
182+
const vec2 loadd = vec2(data_a[ib].dm);
183183

184184
const uint scidx0 = (is < 4) ? is : (is + 4);
185185
const uint scidx1 = (is < 4) ? is : (is - 4);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, cons
233233
#ifdef MMQ_SHMEM
234234
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
235235
const uint ib_k = ib / 8;
236-
const uint iqs_k = (ib % 8) * 8 + iqs * 4;
236+
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
237237

238238
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
239239
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
@@ -279,6 +279,63 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
279279
#endif // MMQ_SHMEM
280280
#endif
281281

282+
#if defined(DATA_A_Q4_K)
283+
// 4-byte loads for Q4_K blocks (144 bytes)
284+
// Q4_K variant: turns an integer dot-product sum into the final accumulator value.
// Returns dsb.x * dma.x * q_sum - dma.y * dsb.y, converted to ACC_TYPE.
//   q_sum       - integer sum of packed dot products (see mmq_dot_product)
//   dma         - two-component factor pair from the A-side cache entry
//   dsb         - two-component factor pair from the B-side cache entry
//   sum_divisor - not used by this variant
// NOTE(review): presumably dma = (scale, min) of the Q4_K sub-block and
// dsb = (d, sum) of the q8_1 block - confirm against the block definitions.
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
285+
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
286+
}
287+
288+
#ifdef MMQ_SHMEM
289+
// Q4_K variant: loads one packed word of quants (and, for iqs == 0, the combined
// scale/min pair) from data_a / data_a_packed32 into buf_a[buf_ib].
//   buf_ib - destination index into buf_a
//   ib     - source sub-block index; ib / 8 selects the data_a block and
//            ib % 8 the position inside it
//   iqs    - index of the destination word in buf_a[buf_ib].qs
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
290+
const uint ib_k = ib / 8;
291+
// Position within the source block, advanced by QUANT_R_MMQ per iqs step.
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
292+

293+
// Word index into the packed qs array and the nibble (0 or 4 bit shift) to extract.
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
294+
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
295+

296+
// Repack 2x4 quants into one int: each masked word keeps one 4-bit value per byte,
// the second word is moved into the upper nibbles of the same bytes.
297+
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
298+
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
299+

300+
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
301+

302+
// Only the iqs == 0 invocation writes the dm pair for this entry.
if (iqs == 0) {
303+
// Scale index (equals ib % 8 here, since iqs == 0)
304+
const uint is = iqs_k / 8;
305+
u8vec2 scale_dm;
306+
if (is < 4) {
307+
// First four entries: low 6 bits of scales[is] and scales[is + 4].
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
308+
} else {
309+
// Remaining entries: 4 low bits from scales[is + 4] (low / high nibble) combined
// with the top 2 bits of scales[is - 4] / scales[is] as bits 4..5.
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
310+
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
311+
}
312+

313+
// Component-wise product of the block's dm pair with the unpacked 6-bit pair.
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
314+
}
315+
}
316+
317+
// Q4_K variant: copies the entry buf_a[buf_ib] into cache_a[reg_ib]
// (the dm pair plus the four packed qs words).
//   reg_ib - destination index into cache_a
//   buf_ib - source index into buf_a
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
318+
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
319+

320+
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
321+
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
322+
}
323+
}
324+
325+
// Q4_K variant: accumulates eight packed dot products between cache_a[ib_a] and
// cache_b, then scales the integer sum via mul_q8_1.
//   ib_a - index of the A-side entry in cache_a
// Returns the ACC_TYPE value produced by mul_q8_1.
ACC_TYPE mmq_dot_product(const uint ib_a) {
326+
int32_t q_sum = 0;
327+

328+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
329+
// Each cached word holds two nibble planes: even iqs takes the low nibbles,
// odd iqs the high nibbles, leaving one 4-bit value per byte.
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
330+

331+
// NOTE(review): presumably a 4-way packed 8-bit integer dot product - confirm
// against the integer dot product extension in use.
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
332+
}
333+

334+
// The last argument (sum_divisor) is passed as the literal 1.
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
335+
}
336+
#endif // MMQ_SHMEM
337+
#endif
338+
282339
#ifdef MMQ_SHMEM
283340
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
284341
const uint ib_outer = ib / 4;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,25 @@ struct block_a_cache {
2626
};
2727
#elif defined(DATA_A_Q8_0)
2828
#define QUANT_R_MMQ 1
29+
// AMD likes 4, Intel likes 1 and Nvidia likes 2
30+
#define BK_STEP 4
2931
struct block_a_cache {
3032
int32_t qs[32/4];
3133
FLOAT_TYPE dm;
3234
};
3335
#elif defined(DATA_A_Q2_K)
3436
#define QUANT_R_MMQ 4
35-
struct block_a_cache
36-
{
37+
struct block_a_cache {
3738
uint32_t qs[2];
3839
u8vec2 scales;
3940
FLOAT_TYPE_VEC2 dm;
4041
};
42+
#elif defined(DATA_A_Q4_K)
43+
#define QUANT_R_MMQ 2
44+
// Cache entry used when DATA_A_Q4_K is defined: four 32-bit words of packed
// quants plus a two-component dm value.
// NOTE(review): presumably dm holds the pre-multiplied (scale, min) pair for one
// sub-block - confirm against the code that fills this struct.
struct block_a_cache {
45+
uint32_t qs[4];
46+
FLOAT_TYPE_VEC2 dm;
47+
};
4148
#endif
4249

4350
struct block_b_cache

ggml/src/ggml-vulkan/vulkan-shaders/types.glsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,21 +288,21 @@ struct block_q3_K_packed16
288288

289289
struct block_q4_K
290290
{
291-
f16vec2 d;
291+
f16vec2 dm;
292292
uint8_t scales[3*QUANT_K_Q4_K/64];
293293
uint8_t qs[QUANT_K_Q4_K/2];
294294
};
295295

296296
struct block_q4_K_packed16
297297
{
298-
f16vec2 d;
298+
f16vec2 dm;
299299
uint16_t scales[3*QUANT_K_Q4_K/64/2];
300300
uint16_t qs[QUANT_K_Q4_K/2/2];
301301
};
302302

303303
struct block_q4_K_packed32
304304
{
305-
f16vec2 d;
305+
f16vec2 dm;
306306
uint32_t scales[3*QUANT_K_Q4_K/64/4];
307307
uint32_t qs[QUANT_K_Q4_K/2/4];
308308
};

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
567567

568568
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
569569
// Integer dot mmq performs better with f32 accumulators
570-
if (!f16acc && !coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k")) {
570+
if (!f16acc && !coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k" || tname == "q4_k")) {
571571
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
572572
}
573573
#endif

0 commit comments

Comments
 (0)