Skip to content

Commit 7ecd780

Browse files
authored
vulkan: Use fp16 for the flash attention P*V multiplication (#12783)
This is consistent with the ggml-cuda behavior and the mul_mat fallback.
1 parent 7538246 commit 7ecd780

File tree

1 file changed: +4 additions, −2 deletions

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -330,9 +330,11 @@ void main() {
             // resize eM by using smear/reduce
             coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

-            O = eMdiag * O;
+            // multiply with fp16 accumulation, then add to O.
+            coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+            PV = coopMatMulAdd(P_A, V, PV);

-            O = coopMatMulAdd(P_A, V, O);
+            O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
         }

         // If there is split_k, then the split_k resolve shader does the final
```

0 commit comments

Comments (0)