Skip to content

Commit 7c5f8de

Browse files
committed
vulkan: use q8_1_x4 blocks in mul_mmq shader
1 parent 3a57953 commit 7c5f8de

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5763,7 +5763,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
57635763
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
57645764

57655765
if (quantize_y) {
5766-
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
5766+
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
57675767
}
57685768

57695769
if (dryrun) {
@@ -5780,7 +5780,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
57805780
ctx->prealloc_size_x = x_sz_upd;
57815781
}
57825782
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
5783-
ctx->prealloc_size_y = y_sz_upd;
5783+
ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
57845784
}
57855785
if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
57865786
ctx->prealloc_size_split_k = split_k_size;
@@ -5871,7 +5871,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
58715871
if (ctx->prealloc_y_need_sync) {
58725872
ggml_vk_sync_buffers(ctx, subctx);
58735873
}
5874-
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
5874+
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
58755875
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
58765876
ctx->prealloc_y_last_tensor_used = src1;
58775877
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
2828
#if defined(A_TYPE_PACKED32)
2929
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
3030
#endif
31-
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
31+
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
3232
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
3333

3434
#ifdef MUL_MAT_ID
@@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
9898
#endif
9999

100100
#define LOAD_VEC_A (4 * QUANT_R)
101-
#define LOAD_VEC_B 4
101+
#define LOAD_VEC_B 16
102102

103103
#ifdef MUL_MAT_ID
104104
shared u16vec2 row_ids[4096];
@@ -270,15 +270,22 @@ void main() {
270270
const uint iqs = idx & 0x7;
271271
#else
272272
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
273+
const uint ib_outer = ib / 4;
274+
const uint ib_inner = ib % 4;
275+
273276
const uint iqs = loadr_b;
274277
#endif
275278

276279
const uint buf_ib = loadc_b + l;
277280

278281
if (iqs == 0) {
279-
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
282+
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
280283
}
281-
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
284+
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
285+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
286+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
287+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
288+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
282289
}
283290

284291
barrier();

0 commit comments

Comments
 (0)