Skip to content

Commit 6d83a8d

Browse files
committed
Add q3_k mmq
1 parent 5148d4a commit 6d83a8d

File tree

4 files changed

+76
-1
lines changed

4 files changed

+76
-1
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2973,6 +2973,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
29732973
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
29742974

29752975
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
2976+
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
29762977
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
29772978
}
29782979
#endif
@@ -3095,6 +3096,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
30953096
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
30963097

30973098
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3099+
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
30983100
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
30993101
}
31003102
#endif

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,73 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
279279
#endif // MMQ_SHMEM
280280
#endif
281281

282+
#if defined(DATA_A_Q3_K)
283+
// 2-byte loads for Q3_K blocks (110 bytes)
284+
#ifdef MMQ_SHMEM
285+
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
286+
const uint ib_k = ib / 8;
287+
const uint hm_idx = iqs * QUANT_R_MMQ;
288+
const uint iqs_k = (ib % 8) * 8 + hm_idx;
289+
290+
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
291+
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
292+
const uint hm_shift = iqs_k / 8;
293+
294+
// Repack 2x4 quants into one int
295+
// Add the 3rd bit instead of subtracting it to allow packing the quants
296+
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
297+
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
298+
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
299+
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
300+
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
301+
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
302+
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
303+
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
304+
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
305+
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
306+
307+
if (iqs == 0) {
308+
const uint is = iqs_k / 4;
309+
const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
310+
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
311+
312+
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
313+
}
314+
}
315+
316+
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
317+
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
318+
319+
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
320+
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
321+
}
322+
}
323+
324+
ACC_TYPE mmq_dot_product(const uint ib_a) {
325+
float result = 0.0;
326+
int32_t q_sum = 0;
327+
328+
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
329+
// Subtract 4 from the quants to correct the 3rd bit offset
330+
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
331+
332+
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
333+
}
334+
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
335+
q_sum = 0;
336+
337+
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
338+
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
339+
340+
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
341+
}
342+
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
343+
344+
return ACC_TYPE(cache_b.ds.x * result);
345+
}
346+
#endif // MMQ_SHMEM
347+
#endif
348+
282349
#if defined(DATA_A_Q4_K)
283350
// 4-byte loads for Q4_K blocks (144 bytes)
284351
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ struct block_a_cache {
3939
u8vec2 scales;
4040
FLOAT_TYPE_VEC2 dm;
4141
};
42+
#elif defined(DATA_A_Q3_K)
43+
#define QUANT_R_MMQ 2
44+
struct block_a_cache {
45+
uint32_t qs[4];
46+
FLOAT_TYPE_VEC2 d_scales;
47+
};
4248
#elif defined(DATA_A_Q4_K)
4349
#define QUANT_R_MMQ 2
4450
struct block_a_cache {

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
567567

568568
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
569569
// Integer dot mmq performs better with f32 accumulators
570-
if (!f16acc && !coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k" || tname == "q4_k")) {
570+
if (!f16acc && !coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k" || tname == "q3_k" || tname == "q4_k")) {
571571
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
572572
}
573573
#endif

0 commit comments

Comments
 (0)