
Commit e83d158

vulkan: add q8_1_x4 type with 128-bit alignment, use in mul_mat_vecq shader
1 parent 1353011 commit e83d158

5 files changed: +108 additions, −39 deletions

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 23 additions & 21 deletions
@@ -428,6 +428,7 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_matmul_split_k_reduce;
     vk_pipeline pipeline_quantize_q8_1;
+    vk_pipeline pipeline_quantize_q8_1_x4;
 
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
@@ -2900,8 +2901,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     if (device->subgroup_clustered && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
     }
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -5352,20 +5355,20 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
 }
 
-static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
+static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) {
     switch(type) {
         case GGML_TYPE_Q8_1:
-            return ctx->device->pipeline_quantize_q8_1;
+            return use_x4_blocks ? ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1;
         default:
             std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
             GGML_ABORT("fatal error");
     }
 }
 
-static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
+static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) {
     VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
 
-    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+    vk_pipeline pipeline = use_x4_blocks ? ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true) : ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
 
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
@@ -5485,7 +5488,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 
     if (quantize_y) {
-        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
     }
 
     if (dryrun) {
@@ -5653,16 +5656,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
 
     const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
-
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne11 * ne01;
-
-    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
-    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
-    const uint64_t d_sz = sizeof(float) * d_ne;
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
 
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
@@ -5675,8 +5669,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
 
-    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
-
     // Check for mmq first
     vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11) : nullptr;
     vk_pipeline to_q8_1 = nullptr;
@@ -5688,7 +5680,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     }
 
     if (quantize_y) {
-        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
     }
 
     const bool qx_needs_dequant = x_non_contig;
@@ -5701,6 +5693,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
     GGML_ASSERT(dmmv != nullptr);
 
+    const uint64_t x_ne = ne01 * ne00;
+    const uint64_t y_ne = ne11 * ne10;
+    const uint64_t d_ne = ne11 * ne01;
+
+    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
+    const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
+    const uint64_t d_sz = sizeof(float) * d_ne;
+
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;
@@ -5713,7 +5715,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
             ctx->prealloc_size_x = x_sz_upd;
         }
         if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = y_sz_upd;
+            ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
         }
 
         // Request descriptor sets
@@ -5758,7 +5760,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         d_Y = ctx->prealloc_y;
     } else if (quantize_y) {
         d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
+        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 128) * 128);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -5774,7 +5776,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
     }
     if (quantize_y) {
-        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
     }
 
     // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
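Note on the new 128-byte rounding in the dryrun path: a block_q8_1_x4 group fuses four Q8_1 blocks, so the quantize shader may touch a whole group even when the tensor only partially fills it, and the preallocated Y buffer is padded past the exact Q8_1 byte size. A minimal sketch of the arithmetic, assuming CEIL_DIV has the usual ggml definition:

#include <cstdint>

// Assumed definition; matches the common ggml idiom for CEIL_DIV.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Pad a Q8_1 Y-buffer size up to a 128-byte multiple, as the dryrun
// path above now does for prealloc_size_y.
uint64_t padded_y_size(uint64_t y_sz_upd) {
    return CEIL_DIV(y_sz_upd, 128) * 128;   // e.g. 1300 -> 1408
}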

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp

Lines changed: 51 additions & 18 deletions
@@ -13,7 +13,7 @@
 #endif
 
 #define MMQ
-#define B_TYPE block_q8_1_packed32
+#define B_TYPE block_q8_1_x4_packed128
 
 #include "mul_mat_vec_base.comp"
 
@@ -80,7 +80,7 @@ void reduce_result_grouped(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const i
 }
 #endif
 
-int32_t cache_b_qs[8];
+ivec4 cache_b_qs[2];
 vec2 cache_b_ds;
 
 void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid_in_group, const uint i) {
@@ -89,10 +89,11 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
 
     // Preload data_b block
     const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
-    cache_b_ds = vec2(data_b[b_block_idx].ds);
-    [[unroll]] for (uint k = 0; k < 8; k++) {
-        cache_b_qs[k] = data_b[b_block_idx].qs[k];
-    }
+    const uint b_block_idx_outer = b_block_idx / 4;
+    const uint b_block_idx_inner = b_block_idx % 4;
+    cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
+    cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 2];
+    cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 2 + 1];
 
     uint ibi = first_row*p.ncols;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
@@ -101,19 +102,51 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
 
         int32_t q_sum = 0;
 #if QUANT_R == 2
-        [[unroll]] for (uint k = 0; k < 4; k++) {
-            const i32vec2 data_a_qs = repack(a_block_idx, k);
-            q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                     cache_b_qs[k]);
-            q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                     cache_b_qs[k + 4]);
-        }
+        i32vec2 data_a_qs = repack(a_block_idx, 0);
+        q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                 cache_b_qs[0].x);
+        q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                 cache_b_qs[1].x);
+        data_a_qs = repack(a_block_idx, 1);
+        q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                 cache_b_qs[0].y);
+        q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                 cache_b_qs[1].y);
+        data_a_qs = repack(a_block_idx, 2);
+        q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                 cache_b_qs[0].z);
+        q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                 cache_b_qs[1].z);
+        data_a_qs = repack(a_block_idx, 3);
+        q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                 cache_b_qs[0].w);
+        q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                 cache_b_qs[1].w);
 #else
-        [[unroll]] for (uint k = 0; k < 8; k++) {
-            const int32_t data_a_qs = repack(a_block_idx, k);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[k]);
-        }
+        int32_t data_a_qs = repack(a_block_idx, 0);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[0].x);
+        data_a_qs = repack(a_block_idx, 1);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[0].y);
+        data_a_qs = repack(a_block_idx, 2);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[0].z);
+        data_a_qs = repack(a_block_idx, 3);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[0].w);
+        data_a_qs = repack(a_block_idx, 4);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[1].x);
+        data_a_qs = repack(a_block_idx, 5);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[1].y);
+        data_a_qs = repack(a_block_idx, 6);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[1].z);
+        data_a_qs = repack(a_block_idx, 7);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[1].w);
 #endif
 
 #if QUANT_AUXF == 1
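The unrolled body preserves the old loop's pairing exactly: word k of the former int32_t cache_b_qs[8] is now component k % 4 of vector k / 4, so eight 32-bit loads become two 128-bit loads. A small host-side sketch of that mapping (illustrative C++, not shader code; the ivec4 mock is an assumption):

#include <array>
#include <cstdint>

struct ivec4 { int32_t v[4]; };  // stand-in for GLSL ivec4 (v[0]=x ... v[3]=w)

// Word k of the old int32_t cache_b_qs[8] layout now lives at
// component k % 4 of vector k / 4.
int32_t old_word(const std::array<ivec4, 2>& cache_b_qs, int k) {
    return cache_b_qs[k / 4].v[k % 4];   // e.g. k = 5 -> cache_b_qs[1].y
}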

ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp

Lines changed: 19 additions & 0 deletions
@@ -23,7 +23,11 @@ layout(constant_id = 0) const uint GROUP_SIZE = 32;
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {vec4 data_a[];};
+#ifndef QBLOCK_X4
 layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
+#else
+layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
+#endif
 
 #ifndef USE_SUBGROUPS
 shared float shmem[GROUP_SIZE];
@@ -45,6 +49,11 @@ void quantize() {
         return;
     }
 
+#ifdef QBLOCK_X4
+    const uint ibx4_outer = ib / 4;
+    const uint ibx4_inner = ib % 4;
+#endif
+
     const uint a_idx = ib * 8 + iqs;
 
     vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f);
@@ -70,7 +79,13 @@ void quantize() {
     const float d = amax / 127.0;
     const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
     vals = round(vals * d_inv);
+
+#ifndef QBLOCK_X4
     data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
+#else
+    data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals)));
+#endif
+
     barrier();
 
     // Calculate the sum for each block
@@ -92,7 +107,11 @@ void quantize() {
         const float sum = shmem[tid];
 #endif
 
+#ifndef QBLOCK_X4
         data_b[ib].ds = f16vec2(vec2(d, sum * d));
+#else
+        data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d));
+#endif
     }
 }
 
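Under QBLOCK_X4 the shader still quantizes one logical Q8_1 block per invocation; only the destination addressing changes. A flat block index ib lands in x4 group ib / 4 at slot ib % 4, its scale pair goes to ds[ib % 4], and its packed quant word iqs goes to qs[(ib % 4) * 8 + iqs]. A host-side sketch of that mapping (names are illustrative, not from the source):

#include <cstdint>

// Illustrative mapping from a flat Q8_1 block index to its position in
// the block_q8_1_x4 layout, mirroring the ibx4_outer/ibx4_inner math above.
struct X4Slot {
    uint32_t outer;     // which block_q8_1_x4 group
    uint32_t inner;     // slot 0..3 within the group
    uint32_t qs_base;   // first of the block's 8 words in qs[32]
};

X4Slot map_q8_1_block(uint32_t ib) {
    return { ib / 4, ib % 4, (ib % 4) * 8 };
}
// e.g. ib = 6: outer = 1, inner = 2, so ds goes to data_b[1].ds[2] and
// quant word iqs goes to data_b[1].qs[2 * 8 + iqs].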

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 12 additions & 0 deletions
@@ -207,6 +207,18 @@ struct block_q8_1_packed32
     int32_t qs[8];
 };
 
+// 4 blocks in one to allow 16-byte/128-bit alignment and loads
+struct block_q8_1_x4
+{
+    f16vec2 ds[4];
+    int32_t qs[32];
+};
+struct block_q8_1_x4_packed128
+{
+    f16vec2 ds[4];
+    ivec4 qs[8];
+};
+
 // K-quants
 #define QUANT_K_Q2_K 256
 
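The comment above captures why the fused type exists: a single block_q8_1_packed32 is 36 bytes, so consecutive blocks cannot stay 16-byte aligned, while four fused blocks with the scales hoisted to the front occupy 144 bytes (a multiple of 16) with qs starting at offset 16, letting the *_packed128 view read quants as ivec4. A host-side sketch checking those sizes (assuming f16vec2 maps to two 16-bit halves and the structs carry no padding):

#include <cstddef>
#include <cstdint>

// Host mirrors of the GLSL structs, under the stated layout assumptions.
struct block_q8_1_packed32 {
    uint16_t ds[2];     // f16vec2: d and d * sum
    int32_t  qs[8];     // 32 int8 quants, packed 4 per int32
};
struct block_q8_1_x4 {
    uint16_t ds[4][2];  // all four scale pairs first: 16 bytes
    int32_t  qs[32];    // then 4 x 8 packed quant words
};

static_assert(sizeof(block_q8_1_packed32) == 36, "36 bytes: not 16-aligned block to block");
static_assert(offsetof(block_q8_1_x4, qs) == 16,  "qs begins on a 128-bit boundary");
static_assert(sizeof(block_q8_1_x4) == 144,       "144 = 9 * 16: every group stays 16-byte aligned");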

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 0 deletions
@@ -580,6 +580,9 @@ void process_shaders() {
     string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
     string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});
 
+    string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}});
+    string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}});
+
     string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
