Skip to content

Commit 269374e

Browse files
JohannesGaessler authored and ggerganov committed
adjust kernel selection logic
1 parent 81da919 commit 269374e

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

ggml-cuda/fattn.cu

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -579,17 +579,15 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
579579
return;
580580
}
581581

582-
int cols_per_block = 16;
583-
if (Q->ne[0] % 32 == 0) {
584-
if (Q->ne[1] >= 128 && Q->ne[0] <= 128) {
585-
cols_per_block = 64;
586-
} else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
587-
cols_per_block = 32;
588-
} else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
589-
cols_per_block = 16;
590-
} else {
591-
cols_per_block = 8;
592-
}
582+
int cols_per_block;
583+
if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
584+
cols_per_block = 64;
585+
} else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
586+
cols_per_block = 32;
587+
} else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
588+
cols_per_block = 16;
589+
} else {
590+
cols_per_block = 8;
593591
}
594592
const int frag_m = cols_per_block == 8 ? 32 : 16;
595593
const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;

0 commit comments

Comments
 (0)