Skip to content

Commit 2661a07

Browse files
committed
CUDA: fix Volta FlashAttention logic (ggml-org#11615)
Author : Johannes Gaessler
1 parent 9673530 commit 2661a07

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
561561
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
562562
break;
563563
// case 256:
564-
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
564+
// ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
565565
// break;
566566
default:
567567
GGML_ABORT("fatal error");

ggml/src/ggml-cuda/fattn.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
334334
return;
335335
}
336336

337-
if (!new_mma_available(cc)) {
337+
if (!fp16_mma_available(cc)) {
338338
if (prec == GGML_PREC_DEFAULT) {
339339
if (Q->ne[1] <= 8) {
340340
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
@@ -364,6 +364,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
364364
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
365365
if (cc == GGML_CUDA_CC_VOLTA) {
366366
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
367+
return;
367368
}
368369

369370
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);

0 commit comments

Comments
 (0)