File tree Expand file tree Collapse file tree 1 file changed +9
-11
lines changed
Expand file tree Collapse file tree 1 file changed +9
-11
lines changed Original file line number Diff line number Diff line change @@ -579,17 +579,15 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
579579 return ;
580580 }
581581
582- int cols_per_block = 16 ;
583- if (Q->ne [0 ] % 32 == 0 ) {
584- if (Q->ne [1 ] >= 128 && Q->ne [0 ] <= 128 ) {
585- cols_per_block = 64 ;
586- } else if (Q->ne [1 ] >= 64 && (Q->ne [0 ] <= 128 || ggml_cuda_info ().devices [ctx.device ].cc >= CC_AMPERE)) {
587- cols_per_block = 32 ;
588- } else if (Q->ne [1 ] >= 32 || Q->ne [0 ] % 32 != 0 ) {
589- cols_per_block = 16 ;
590- } else {
591- cols_per_block = 8 ;
592- }
582+ int cols_per_block;
583+ if (Q->ne [1 ] >= 128 && Q->ne [0 ] <= 128 && Q->ne [0 ] % 32 == 0 ) {
584+ cols_per_block = 64 ;
585+ } else if (Q->ne [1 ] >= 64 && (Q->ne [0 ] <= 128 || ggml_cuda_info ().devices [ctx.device ].cc >= CC_AMPERE)) {
586+ cols_per_block = 32 ;
587+ } else if (Q->ne [1 ] >= 32 || Q->ne [0 ] % 32 != 0 ) {
588+ cols_per_block = 16 ;
589+ } else {
590+ cols_per_block = 8 ;
593591 }
594592 const int frag_m = cols_per_block == 8 ? 32 : 16 ;
595593 const int nwarps = (Q->ne [0 ] <= 128 || cols_per_block == 8 ? Q->ne [0 ] : Q->ne [0 ]/2 ) / frag_m;
You can’t perform that action at this time.
0 commit comments