@@ -208,8 +208,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
 
     const int cc = ggml_cuda_info().devices[device].cc;
-    const int warp_size = ggml_cuda_info().devices[device].warp_size;
-    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     switch (K->ne[0]) {
         case 64:
@@ -267,29 +265,31 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         return BEST_FATTN_KERNEL_NONE;
     }
 
-    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
 
     // If Turing tensor cores available, use them except for some cases with batch size 1:
     if (turing_mma_available(cc)) {
         best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
 
-        if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
-            if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
-                best = BEST_FATTN_KERNEL_VEC;
-            }
-        } else {
-            if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
-                if (Q->ne[1] <= 2) {
+        if (can_use_vector_kernel) {
+            if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                     best = BEST_FATTN_KERNEL_VEC;
                 }
             } else {
-                if (Q->ne[1] == 1) {
-                    best = BEST_FATTN_KERNEL_VEC;
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    if (Q->ne[1] <= 2) {
+                        best = BEST_FATTN_KERNEL_VEC;
+                    }
+                } else {
+                    if (Q->ne[1] == 1) {
+                        best = BEST_FATTN_KERNEL_VEC;
+                    }
                 }
             }
-        }
-        if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
-            best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
+            if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
+                best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
+            }
         }
 
         return best;
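
Note: a minimal standalone sketch (not the ggml implementation) of the head-size gate used above. The hard-coded 64 corresponds to 2*warp_size with the NVIDIA warp size of 32, which is why the warp_size lookup could be dropped. The head-size list and the main() driver are illustrative assumptions, not values taken from the code.

// sketch.cpp -- illustrative only; mirrors the can_use_vector_kernel condition from the diff.
#include <cstdint>
#include <cstdio>

// The vector kernel is only a candidate for head sizes up to 256 that are a
// multiple of 64 (== 2*warp_size, with warp_size fixed at 32 on NVIDIA GPUs).
static bool can_use_vector_kernel(int64_t head_size) {
    return head_size <= 256 && head_size % 64 == 0;
}

int main() {
    const int64_t head_sizes[] = {64, 80, 96, 112, 128, 256}; // assumed typical values
    for (int64_t hs : head_sizes) {
        std::printf("head size %3lld -> vector kernel candidate: %s\n",
                    (long long) hs, can_use_vector_kernel(hs) ? "yes" : "no");
    }
    return 0;
}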