Skip to content

Commit 3238b14

Browse files
authored
vulkan: Fix data race/hang in scalar/cm1 flash attention (ggml-org#17887)
1 parent 4722671 commit 3238b14

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ void main() {
256256
barrier();
257257
}
258258

259+
// prevent race on tmpsh
260+
barrier();
261+
259262
// reduce across threads
260263

261264
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,9 @@ void main() {
302302
barrier();
303303
}
304304

305+
// prevent race on tmpsh
306+
barrier();
307+
305308
// reduce across threads
306309

307310
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];

0 commit comments

Comments
 (0)