1 file changed: +3 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -446,8 +446,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
446446
447447    const  bool  gqa_opt_applies = ((Q->ne [2 ] / K->ne [2 ]) % 2  == 0 ) && mask; //  The mma-based kernels have GQA-specific optimizations
448448    const  bool  mma_needs_data_conversion = K->type  != GGML_TYPE_F16 || V->type  != GGML_TYPE_F16;
449-     const  bool  mma_faster_for_bs1 = new_mma_available (cc) && gqa_opt_applies &&
450-         (Q->ne [3 ] > 1  || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
449+     const  bool  mma_faster_for_rtx4000 = Q->ne [3 ] > 1  || (Q->ne [2 ] > 4 *K->ne [2 ] && K->ne [1 ] >= 8192 );
450+     const  bool  mma_faster_for_bs1 = new_mma_available (cc) && gqa_opt_applies && !mma_needs_data_conversion &&
451+         (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
451452    const  bool  can_use_vector_kernel = Q->ne [0 ] <= 256  && Q->ne [0 ] % (2 *warp_size) == 0 ;
452453    if  (Q->ne [1 ] == 1  && can_use_vector_kernel && !mma_faster_for_bs1) {
453454        if  (prec == GGML_PREC_DEFAULT) {
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments