
Commit 68d793b

JohannesGaessler authored and ggerganov committed

no ncols == 64

1 parent cca6d02 · commit 68d793b

File tree

1 file changed: +1, -9 lines changed


ggml-cuda/fattn.cu

Lines changed: 1 addition & 9 deletions
@@ -625,9 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     int cols_per_block;
-    if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
-        cols_per_block = 64;
-    } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
+    if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
         cols_per_block = 32;
     } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
         cols_per_block = 16;
@@ -645,7 +643,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(64, 8, nwarps);
             FATTN_SWITCH_CASE(64, 16, nwarps);
             FATTN_SWITCH_CASE(64, 32, nwarps);
-            FATTN_SWITCH_CASE(64, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -655,7 +652,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             // FATTN_SWITCH_CASE(80, 8, nwarps);
             FATTN_SWITCH_CASE(80, 16, nwarps);
             FATTN_SWITCH_CASE(80, 32, nwarps);
-            // FATTN_SWITCH_CASE(80, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -665,7 +661,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(96, 8, nwarps);
             FATTN_SWITCH_CASE(96, 16, nwarps);
             FATTN_SWITCH_CASE(96, 32, nwarps);
-            FATTN_SWITCH_CASE(96, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -675,7 +670,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             // FATTN_SWITCH_CASE(112, 8, nwarps);
             FATTN_SWITCH_CASE(112, 16, nwarps);
             FATTN_SWITCH_CASE(112, 32, nwarps);
-            // FATTN_SWITCH_CASE(112, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -685,7 +679,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(128, 8, nwarps);
             FATTN_SWITCH_CASE(128, 16, nwarps);
             FATTN_SWITCH_CASE(128, 32, nwarps);
-            // FATTN_SWITCH_CASE(128, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -695,7 +688,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(256, 8, nwarps);
             FATTN_SWITCH_CASE(256, 16, nwarps);
             FATTN_SWITCH_CASE(256, 32, nwarps);
-            // FATTN_SWITCH_CASE(256, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
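
For reference, a minimal sketch of how the cols_per_block selection reads after this change, reconstructed from the diff context above; the final else branch (cols_per_block = 8) is not shown in the diff and is an assumption inferred from the ncols == 8 FATTN_SWITCH_CASE entries, and the exact indentation is approximate.

    // Sketch only: choosing the number of query columns processed per CUDA block
    // after the ncols == 64 path is removed. Q->ne[0] is the head size, Q->ne[1]
    // the number of queries in the batch.
    int cols_per_block;
    if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
        cols_per_block = 32; // large batches, small heads or Ampere and newer GPUs
    } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
        cols_per_block = 16;
    } else {
        cols_per_block = 8;  // assumed fallback, matching the ncols == 8 kernel cases
    }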
