Skip to content

Commit fb34984

Browse files
authored
vulkan: Handle FA with all -inf mask values (ggml-org#16447)
1 parent 6de8ed7 commit fb34984

File tree

4 files changed

+8 -4 lines changed

4 files changed

+8 -4 lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -345,7 +345,7 @@ void main() {
345345

346346
float Lfrcp[Br];
347347
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
348-
Lfrcp[r] = 1.0 / Lf[r];
348+
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
349349
}
350350

351351
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -380,7 +380,7 @@ void main() {
380380

381381
float Lfrcp[rows_per_thread];
382382
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
383-
Lfrcp[r] = 1.0 / Lf[r];
383+
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
384384
}
385385

386386
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,11 @@ void main() {
121121
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
122122

123123
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
124+
#if defined(ACC_TYPE_MAX)
125+
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
126+
#else
124127
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
128+
#endif
125129

126130
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
127131

@@ -294,7 +298,7 @@ void main() {
294298

295299
[[unroll]]
296300
for (int k = 0; k < Ldiag.length(); ++k) {
297-
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
301+
Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
298302
}
299303

300304
O = Ldiag*O;

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ void main() {
9191
L = L*ms + vs;
9292
}
9393

94-
L = 1.0 / L;
94+
L = (L == 0.0) ? 0.0 : 1.0 / L;
9595

9696
// D dimension is split across workgroups in the y dimension
9797
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;

0 commit comments

Comments
 (0)