Skip to content

Commit 1309d7d

Browse files
committed
Use 32-bit accumulators for integer dot matmul
1 parent e978f66 commit 1309d7d

File tree

3 files changed

+24
-27
lines changed

3 files changed

+24
-27
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2448,6 +2448,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
24482448
l_warptile_id, m_warptile_id, s_warptile_id,
24492449
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
24502450
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
2451+
l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
24512452
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
24522453
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
24532454
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
@@ -2517,6 +2518,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
25172518
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
25182519
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
25192520

2521+
// K-quants use even more registers, mitigate by setting WMITER to 1
2522+
l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
2523+
m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
2524+
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
2525+
25202526
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
25212527
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
25222528
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
@@ -2915,15 +2921,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
29152921

29162922
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
29172923
if (device->mul_mat ## ID ## _l[TYPE]) { \
2918-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
29192924
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
29202925
} \
29212926
if (device->mul_mat ## ID ## _m[TYPE]) { \
2922-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
29232927
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
29242928
} \
29252929
if (device->mul_mat ## ID ## _s[TYPE]) { \
2926-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
29272930
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
29282931
} \
29292932

@@ -2969,7 +2972,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
29692972
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
29702973
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
29712974

2972-
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2975+
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
29732976
}
29742977
#endif
29752978

@@ -3090,7 +3093,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
30903093
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
30913094
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
30923095

3093-
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3096+
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
30943097
}
30953098
#endif
30963099

@@ -4933,7 +4936,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
49334936

49344937
// MMQ
49354938
if (src1_type == GGML_TYPE_Q8_1) {
4936-
vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
4939+
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
49374940

49384941
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
49394942
return nullptr;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,10 @@ void main() {
183183
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
184184
#endif
185185

186-
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN / 2];
186+
ACC_TYPE sums[WMITER * TM * WNITER * TN];
187187

188-
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
189-
sums[i] = ACC_TYPE_VEC2(0.0f);
188+
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
189+
sums[i] = ACC_TYPE(0.0f);
190190
}
191191

192192
for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
@@ -240,12 +240,11 @@ void main() {
240240
block_b_to_registers(ib);
241241

242242
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
243-
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
244-
const uint cache_a_idx = wsir * TM + cr * 2;
245-
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
243+
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
244+
const uint cache_a_idx = wsir * TM + cr;
245+
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
246246

247-
sums[sums_idx].x += mmq_dot_product(cache_a_idx);
248-
sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1);
247+
sums[sums_idx] += mmq_dot_product(cache_a_idx);
249248
}
250249
}
251250
}
@@ -274,21 +273,15 @@ void main() {
274273

275274
const u16vec2 row_idx = row_ids[row_i - ic * BN];
276275
#endif // MUL_MAT_ID
277-
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
278-
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
276+
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
277+
const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
279278
#ifdef MUL_MAT_ID
280-
if (dr_warp + 2 * cr < p.M) {
281-
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
282-
}
283-
if (dr_warp + 2 * cr + 1 < p.M) {
284-
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
279+
if (dr_warp + cr < p.M) {
280+
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
285281
}
286282
#else
287-
if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
288-
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
289-
}
290-
if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
291-
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
283+
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
284+
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
292285
}
293286
#endif // MUL_MAT_ID
294287
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
566566
}
567567

568568
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
569-
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k")) {
569+
// Integer dot mmq performs better with f32 accumulators
570+
if (!f16acc && !coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k")) {
570571
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
571572
}
572573
#endif

0 commit comments

Comments
 (0)