Skip to content

Commit 129a0f1

Browse files
committed
vulkan: allow unclamped loads in coopmat2 mul_mat_id shader
1 parent 6491d6e commit 129a0f1

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 20 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -414,17 +414,31 @@ void main() {
414414
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
415415
}
416416

417-
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
418-
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
417+
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
418+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
419+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
419420

420-
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
421+
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
421422
#ifdef MUL_MAT_ID
422-
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
423+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
423424
#else
424-
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
425+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
425426
#endif
426427

427-
sum = coopMatMulAdd(mat_a, mat_b, sum);
428+
sum = coopMatMulAdd(mat_a, mat_b, sum);
429+
} else {
430+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
431+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
432+
433+
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
434+
#ifdef MUL_MAT_ID
435+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
436+
#else
437+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
438+
#endif
439+
440+
sum = coopMatMulAdd(mat_a, mat_b, sum);
441+
}
428442
}
429443

430444
// Convert from ACC_TYPE to D_TYPE

0 commit comments

Comments (0)