Commit 5f8fd1e

vulkan: add q8_1_x4 type with 128-bit alignment, use in mul_mat_vecq shader
1 parent f6c2124 commit 5f8fd1e

5 files changed: +108 −39 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp
Lines changed: 23 additions & 21 deletions

@@ -452,6 +452,7 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_matmul_split_k_reduce;
     vk_pipeline pipeline_quantize_q8_1;
+    vk_pipeline pipeline_quantize_q8_1_x4;
 
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
@@ -3005,8 +3006,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     if (device->subgroup_clustered && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
     }
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -5548,20 +5551,20 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_sync_buffers(ctx, subctx);
 }
 
-static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
+static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) {
     switch(type) {
         case GGML_TYPE_Q8_1:
-            return ctx->device->pipeline_quantize_q8_1;
+            return use_x4_blocks ? ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1;
         default:
             std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
             GGML_ABORT("fatal error");
     }
 }
 
-static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
+static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) {
     VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
 
-    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+    vk_pipeline pipeline = use_x4_blocks ? ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true) : ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
 
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
     ggml_vk_sync_buffers(ctx, subctx);
@@ -5681,7 +5684,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 
     if (quantize_y) {
-        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
     }
 
     if (dryrun) {
@@ -5877,16 +5880,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
 
     const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
-
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne11 * ne01;
-
-    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
-    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
-    const uint64_t d_sz = sizeof(float) * d_ne;
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
 
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
@@ -5899,8 +5893,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
 
-    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
-
     // Check for mmq first
     vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr;
     vk_pipeline to_q8_1 = nullptr;
@@ -5912,7 +5904,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     }
 
     if (quantize_y) {
-        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
     }
 
     const bool qx_needs_dequant = x_non_contig;
@@ -5925,6 +5917,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
     GGML_ASSERT(dmmv != nullptr);
 
+    const uint64_t x_ne = ne01 * ne00;
+    const uint64_t y_ne = ne11 * ne10;
+    const uint64_t d_ne = ne11 * ne01;
+
+    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
+    const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
+    const uint64_t d_sz = sizeof(float) * d_ne;
+
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;
@@ -5937,7 +5939,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
             ctx->prealloc_size_x = x_sz_upd;
         }
         if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = y_sz_upd;
+            ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
        }
 
        // Request descriptor sets
@@ -5982,7 +5984,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         d_Y = ctx->prealloc_y;
     } else if (quantize_y) {
         d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
+        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 128) * 128);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -6014,7 +6016,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         }
     }
     if (quantize_y) {
-        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
    }
 
    // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
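Note on the host-side changes: the mul_mat_vec path now requests the x4 quantize pipeline (`use_x4_blocks = true`), while the regular mmq matmul path keeps the flat layout (`false`), and the preallocated y buffer is rounded up to a 128-byte multiple, presumably as headroom and alignment for the 4-blocks-at-a-time container. A minimal sketch of that rounding with made-up sizes (not values from the commit); `CEIL_DIV` mirrors the macro used in ggml-vulkan.cpp:

#include <cstdint>

#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

int main() {
    // 4096 floats -> 4096 / 32 = 128 q8_1 blocks * 36 bytes = 4608 bytes,
    // already a 128-byte multiple, so no padding is added:
    const uint64_t aligned = CEIL_DIV(uint64_t{4608}, 128) * 128;  // 4608

    // A size that is not a 128-byte multiple gets rounded up:
    const uint64_t padded  = CEIL_DIV(uint64_t{4500}, 128) * 128;  // 36 * 128 = 4608

    return (aligned == 4608 && padded == 4608) ? 0 : 1;
}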

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
Lines changed: 51 additions & 18 deletions

@@ -13,7 +13,7 @@
 #endif
 
 #define MMQ
-#define B_TYPE block_q8_1_packed32
+#define B_TYPE block_q8_1_x4_packed128
 
 #include "mul_mat_vec_base.comp"
 
@@ -80,7 +80,7 @@ void reduce_result_grouped(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const i
 }
 #endif
 
-int32_t cache_b_qs[8];
+ivec4 cache_b_qs[2];
 vec2 cache_b_ds;
 
 void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid_in_group, const uint i) {
@@ -89,10 +89,11 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
 
         // Preload data_b block
         const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
-        cache_b_ds = vec2(data_b[b_block_idx].ds);
-        [[unroll]] for (uint k = 0; k < 8; k++) {
-            cache_b_qs[k] = data_b[b_block_idx].qs[k];
-        }
+        const uint b_block_idx_outer = b_block_idx / 4;
+        const uint b_block_idx_inner = b_block_idx % 4;
+        cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
+        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 2];
+        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 2 + 1];
 
         uint ibi = first_row*p.ncols;
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
@@ -101,19 +102,51 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
 
             int32_t q_sum = 0;
 #if QUANT_R == 2
-            [[unroll]] for (uint k = 0; k < 4; k++) {
-                const i32vec2 data_a_qs = repack(a_block_idx, k);
-                q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                         cache_b_qs[k]);
-                q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                         cache_b_qs[k + 4]);
-            }
+            i32vec2 data_a_qs = repack(a_block_idx, 0);
+            q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                     cache_b_qs[0].x);
+            q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                     cache_b_qs[1].x);
+            data_a_qs = repack(a_block_idx, 1);
+            q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                     cache_b_qs[0].y);
+            q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                     cache_b_qs[1].y);
+            data_a_qs = repack(a_block_idx, 2);
+            q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                     cache_b_qs[0].z);
+            q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                     cache_b_qs[1].z);
+            data_a_qs = repack(a_block_idx, 3);
+            q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                     cache_b_qs[0].w);
+            q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                     cache_b_qs[1].w);
#else
-            [[unroll]] for (uint k = 0; k < 8; k++) {
-                const int32_t data_a_qs = repack(a_block_idx, k);
-                q_sum += dotPacked4x8EXT(data_a_qs,
-                                         cache_b_qs[k]);
-            }
+            int32_t data_a_qs = repack(a_block_idx, 0);
+            q_sum += dotPacked4x8EXT(data_a_qs,
+                                     cache_b_qs[0].x);
+            data_a_qs = repack(a_block_idx, 1);
+            q_sum += dotPacked4x8EXT(data_a_qs,
+                                     cache_b_qs[0].y);
+            data_a_qs = repack(a_block_idx, 2);
+            q_sum += dotPacked4x8EXT(data_a_qs,
+                                     cache_b_qs[0].z);
+            data_a_qs = repack(a_block_idx, 3);
+            q_sum += dotPacked4x8EXT(data_a_qs,
+                                     cache_b_qs[0].w);
+            data_a_qs = repack(a_block_idx, 4);
+            q_sum += dotPacked4x8EXT(data_a_qs,
+                                     cache_b_qs[1].x);
+            data_a_qs = repack(a_block_idx, 5);
+            q_sum += dotPacked4x8EXT(data_a_qs,
+                                     cache_b_qs[1].y);
+            data_a_qs = repack(a_block_idx, 6);
+            q_sum += dotPacked4x8EXT(data_a_qs,
+                                     cache_b_qs[1].z);
+            data_a_qs = repack(a_block_idx, 7);
+            q_sum += dotPacked4x8EXT(data_a_qs,
+                                     cache_b_qs[1].w);
 #endif
 
 #if QUANT_AUXF == 1
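With B_TYPE switched to block_q8_1_x4_packed128, each q8_1 sub-block's eight int32 quant words arrive as two ivec4 (128-bit) loads, and the k-loops are unrolled into explicit statements, presumably because the .x/.y/.z/.w component selection needs compile-time indices. A host-side C++ sketch (our names, not from the source) checking the flat-index → x4-container remapping the shader relies on:

#include <cassert>
#include <cstdint>

int main() {
    for (uint32_t b_block_idx = 0; b_block_idx < 16; ++b_block_idx) {
        const uint32_t outer = b_block_idx / 4;  // which block_q8_1_x4
        const uint32_t inner = b_block_idx % 4;  // which sub-block inside it

        // Each sub-block owns 8 of the container's 32 int32 words, i.e. two
        // of the 8 ivec4s in block_q8_1_x4_packed128:
        const uint32_t first_ivec4  = inner * 2;      // -> cache_b_qs[0]
        const uint32_t second_ivec4 = inner * 2 + 1;  // -> cache_b_qs[1]

        for (uint32_t k = 0; k < 8; ++k) {
            // Word k of the flat q8_1 stream equals word inner*8+k of container `outer`:
            const uint32_t flat_word = b_block_idx * 8 + k;
            const uint32_t x4_word   = outer * 32 + inner * 8 + k;
            assert(flat_word == x4_word);

            // Viewed as ivec4s, word w of the container is qs[w / 4][w % 4]:
            const uint32_t w = inner * 8 + k;
            assert(w / 4 == (k < 4 ? first_ivec4 : second_ivec4));
            assert(w % 4 == k % 4);
        }
    }
    return 0;
}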

ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
Lines changed: 19 additions & 0 deletions

@@ -23,7 +23,11 @@ layout(constant_id = 0) const uint GROUP_SIZE = 32;
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {vec4 data_a[];};
+#ifndef QBLOCK_X4
 layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
+#else
+layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
+#endif
 
 #ifndef USE_SUBGROUPS
 shared float shmem[GROUP_SIZE];
@@ -45,6 +49,11 @@ void quantize() {
         return;
     }
 
+#ifdef QBLOCK_X4
+    const uint ibx4_outer = ib / 4;
+    const uint ibx4_inner = ib % 4;
+#endif
+
     const uint a_idx = ib * 8 + iqs;
 
     vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f);
@@ -70,7 +79,13 @@ void quantize() {
     const float d = amax / 127.0;
     const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
     vals = round(vals * d_inv);
+
+#ifndef QBLOCK_X4
     data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
+#else
+    data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals)));
+#endif
+
     barrier();
 
     // Calculate the sum for each block
@@ -92,7 +107,11 @@ void quantize() {
     const float sum = shmem[tid];
 #endif
 
+#ifndef QBLOCK_X4
     data_b[ib].ds = f16vec2(vec2(d, sum * d));
+#else
+    data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d));
+#endif
 }
 }
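The QBLOCK_X4 path changes only where results are written (container `ib / 4`, slot `ib % 4`), not the quantization itself. For reference, a scalar C++ sketch of the per-block math the shader performs — our reference, not code from the commit; the shader stores d and sum*d as an f16vec2, and GLSL round() may break ties differently than lround:

#include <algorithm>
#include <cmath>
#include <cstdint>

// One q8_1 block: 32 int8 quants plus (d, d * sum_of_quants).
struct Q8_1Ref {
    float  d, s;      // stored as f16vec2 ds in the shader
    int8_t qs[32];
};

Q8_1Ref quantize_block_ref(const float * x) {
    float amax = 0.0f;
    for (int i = 0; i < 32; ++i) amax = std::max(amax, std::fabs(x[i]));

    const float d     = amax / 127.0f;                // largest value maps to +/-127
    const float d_inv = d != 0.0f ? 1.0f / d : 0.0f;

    Q8_1Ref out{};
    int sum = 0;
    for (int i = 0; i < 32; ++i) {
        const int q = (int)std::lround(x[i] * d_inv); // shader: round(vals * d_inv)
        out.qs[i] = (int8_t)q;
        sum += q;                                     // shader: workgroup/subgroup reduction
    }
    out.d = d;
    out.s = (float)sum * d;                           // shader: f16vec2(vec2(d, sum * d))
    return out;
}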

ggml/src/ggml-vulkan/vulkan-shaders/types.comp
Lines changed: 12 additions & 0 deletions

@@ -207,6 +207,18 @@ struct block_q8_1_packed32
     int32_t qs[8];
 };
 
+// 4 blocks in one to allow 16-byte/128-bit alignment and loads
+struct block_q8_1_x4
+{
+    f16vec2 ds[4];
+    int32_t qs[32];
+};
+struct block_q8_1_x4_packed128
+{
+    f16vec2 ds[4];
+    ivec4 qs[8];
+};
+
 // K-quants
 #define QUANT_K_Q2_K 256
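The new comment is the crux of the commit: a single q8_1 block is 36 bytes (4-byte f16vec2 scale pair + 32 quants), so an array of them cannot keep every block on a 16-byte boundary, while the fused container is 144 bytes (16 + 128), a multiple of 16 — which is what lets block_q8_1_x4_packed128 read the quants as ivec4s. A C++ mock-up of the two layouts (f16vec2 modeled as two uint16_t halves, assuming the tight 4-byte packing the GLSL side uses) to check the arithmetic:

#include <cstdint>

struct mock_block_q8_1 {        // flat layout: 36-byte stride
    uint16_t ds[2];
    int32_t  qs[8];
};
struct mock_block_q8_1_x4 {     // 4 blocks fused: 144-byte stride
    uint16_t ds[4][2];
    int32_t  qs[32];
};

static_assert(sizeof(mock_block_q8_1) == 36 && 36 % 16 != 0,
              "flat blocks drift off 16-byte boundaries");
static_assert(sizeof(mock_block_q8_1_x4) == 144 && 144 % 16 == 0,
              "every x4 container starts 16-byte aligned, enabling 128-bit loads");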

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Lines changed: 3 additions & 0 deletions

@@ -588,6 +588,9 @@ void process_shaders() {
     string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
     string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});
 
+    string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}});
+    string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}});
+
     string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
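Note: the generator embeds each compiled shader as a `<name>_data`/`<name>_len` symbol pair, which is exactly what the new `ggml_vk_create_pipeline` calls in ggml-vulkan.cpp above reference (`quantize_q8_1_x4_len`, `quantize_q8_1_x4_data`, and the subgroup variants); the `QBLOCK_X4` define selects the `block_q8_1_x4` branches inside quantize_q8_1.comp at compile time.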
