Skip to content

Commit 5d4ab04

Browse files
committed
Remove the Q->ne[1] > 8 check
1 parent 29debe1 commit 5d4ab04

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -578,10 +578,10 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
578578
return;
579579
}
580580

581+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
581582
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
582583
constexpr int cols_per_block = 8;
583584
switch (Q->ne[0]) {
584-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
585585
case 64:
586586
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
587587
break;
@@ -594,13 +594,13 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
594594
case 256:
595595
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
596596
break;
597-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
598597
default:
599598
GGML_ABORT("fatal error");
600599
break;
601600
}
602601
return;
603602
}
603+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
604604

605605
if (Q->ne[1] <= 32) {
606606
constexpr int cols_per_block = 16;

ggml/src/ggml-cuda/fattn.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -254,7 +254,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
254254

255255
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
256256
#if defined(GGML_HIP_ROCWMMA_FATTN)
257-
if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) {
257+
if (fp16_mma_available(cc)) {
258258
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
259259
return;
260260
}

0 commit comments

Comments
 (0)