Commit c37052a

vulkan: mul_mat_id coopmat2 optimizations (ggml-org#15546)
* vulkan: mul_mat_id coopmat2 optimizations

  Add a path for when the tile fits in BN/2, similar to what we have for
  mul_mat. Only call fetch_scales/store_scales once per QUANT_K block, and
  once at the beginning in case start_k is not aligned.

* Also add a path for BN/4 - worth a couple more percent

1 parent 5c16b9c commit c37052a
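
To make the commit message concrete, here is a rough host-side model in plain C++ of the tile-width choice the shader makes (a sketch under assumptions, not the actual dispatch code; names mirror the shader's `ic`, `BN`, `_ne1`, and `enable_smaller_matrices`). A workgroup covers columns [ic*BN, ic*BN + BN); if everything at or past column ic*BN + BN/4 is already out of range (>= ne1), only the first quarter of the tile holds live rows and a BN/4-wide accumulator suffices, and likewise for BN/2:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical model of the shader's narrow-tile selection.
    uint32_t pick_tile_width(uint32_t ic, uint32_t BN, uint32_t ne1,
                             bool enable_smaller_matrices) {
        if (enable_smaller_matrices && ic * BN + BN / 4 >= ne1) {
            return BN / 4;  // only the first BN/4 columns of this tile are live
        }
        if (enable_smaller_matrices && ic * BN + BN / 2 >= ne1) {
            return BN / 2;  // only the first BN/2 columns are live
        }
        return BN;          // full-width path, as before
    }

    int main() {
        const uint32_t BN = 128;
        // Example: 160 live rows. Tile 0 is full; tile 1 covers rows
        // 128..255 but only 128..159 are live, so BN/4 = 32 suffices.
        printf("tile 0 width: %u\n", pick_tile_width(0, BN, 160, true)); // 128
        printf("tile 1 width: %u\n", pick_tile_width(1, BN, 160, true)); // 32
    }

The likely motivation: with mul_mat_id, expert routing often leaves only a handful of rows for a given expert, so a full 128-wide cooperative-matrix tile is mostly padding; a quarter-width tile does a quarter of the multiply-add and B-load work.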

File tree (2 files changed: +93 -6 lines)

  ggml/src/ggml-vulkan/ggml-vulkan.cpp
  ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
ggml/src/ggml-vulkan/ggml-vulkan.cpp (1 addition, 1 deletion)

@@ -2225,7 +2225,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         s_mmq_wg_denoms_k = { 32, 64, 1 };

         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
+        l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size };
         m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
         s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
         l_mmqid_wg_denoms = { 128, 128, 1 };
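
A plausible reading of this one-value flip, sketched as a C++ struct (the field names are assumptions for illustration, not the actual ggml-vulkan layout): the fifth warptile entry feeds the shader's `enable_smaller_matrices` specialization constant, which the shader diff below gates its new BN/2 and BN/4 paths on. Only the large (l) tile enables it here; the m and s tiles keep 0.

    #include <cstdint>

    // Hypothetical view of the six warptile entries above.
    struct WarptileMMQID {
        uint32_t workgroup_size;          // 256
        uint32_t bm;                      // 128: rows of A per workgroup tile
        uint32_t bn;                      // 128 (l) or 64 (m/s): columns of B per tile
        uint32_t bk;                      // 16: K-step per loop iteration
        uint32_t enable_smaller_matrices; // flipped 0 -> 1: allow BN/2 and BN/4 paths
        uint32_t subgroup_size;
    };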

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp (92 additions, 5 deletions)

@@ -456,18 +456,105 @@ void main() {

     tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);

-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
-    sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
-
     uint k_iters = (end_k - start_k + BK - 1) / BK;

     fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
+    store_scales(tid);
+
+#ifdef MUL_MAT_ID
+    if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum;
+        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
+
+        [[dont_unroll]]
+        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+            if ((block_k % QUANT_K) == 0) {
+                store_scales(tid);
+            }
+            if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+                fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+            }
+
+            if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            } else {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+        }
+
+        // Convert from ACC_TYPE to D_TYPE
+        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d;
+        mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
+
+        // Call callback to store each element, remapping row through shared memory
+        coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+        return;
+    }
+    if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
+        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
+
+        [[dont_unroll]]
+        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+            if ((block_k % QUANT_K) == 0) {
+                store_scales(tid);
+            }
+            if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+                fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+            }
+
+            if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            } else {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+        }
+
+        // Convert from ACC_TYPE to D_TYPE
+        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
+        mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
+
+        // Call callback to store each element, remapping row through shared memory
+        coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+        return;
+    }
+#endif
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
+    sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);

     [[dont_unroll]]
     for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

-        store_scales(tid);
-        if (block_k + BK < end_k) {
+        if ((block_k % QUANT_K) == 0) {
+            store_scales(tid);
+        }
+        if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
             fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
         }

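To see what the fetch_scales/store_scales change buys, here is a small self-contained C++ model of the old and new call cadences. The values of QUANT_K, BK, and the loop bounds are representative stand-ins (assumptions for illustration), not values taken from this diff:

    #include <cstdint>
    #include <cstdio>

    // Count store_scales calls over one K loop, modeling the shader's old
    // cadence (every BK step) vs. the new one (once per QUANT_K block, plus
    // once up front in case start_k is not QUANT_K-aligned).
    int main() {
        const uint32_t BK = 16, QUANT_K = 256;   // representative values (assumed)
        const uint32_t start_k = 0, end_k = 4096;
        const uint32_t k_iters = (end_k - start_k + BK - 1) / BK;

        uint32_t old_stores = 0, new_stores = 0;
        new_stores++;                            // unconditional store before the loop
        uint32_t block_k = start_k;
        for (uint32_t i = 0; i < k_iters; ++i, block_k += BK) {
            old_stores++;                        // old: store_scales every iteration
            if (block_k % QUANT_K == 0) {
                new_stores++;                    // new: only at QUANT_K boundaries
            }
        }
        printf("store_scales calls: old=%u new=%u\n", old_stores, new_stores);
        // With these numbers: old=256, new=17 (one per 256-wide quant block
        // plus the initial store).
    }

Under these assumed sizes, roughly 15 of every 16 shared-memory scale stores (and the matching fetches) disappear, since the scales only change at quant-block boundaries; the savings scale with QUANT_K/BK.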