File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -315,8 +315,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
315315
316316    const  bool  gqa_opt_applies = ((Q->ne [2 ] / K->ne [2 ]) % 2  == 0 ) && mask; //  The mma-based kernels have GQA-specific optimizations
317317    const  bool  mma_needs_data_conversion = K->type  != GGML_TYPE_F16 || V->type  != GGML_TYPE_F16;
318-     const  bool  mma_faster_for_bs1 = new_mma_available (cc) && gqa_opt_applies &&
319-         (Q->ne [3 ] > 1  || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
318+     const  bool  mma_faster_for_rtx4000 = Q->ne [3 ] > 1  || (Q->ne [2 ] > 4 *K->ne [2 ] && K->ne [1 ] >= 8192 );
319+     const  bool  mma_faster_for_bs1 = new_mma_available (cc) && gqa_opt_applies && !mma_needs_data_conversion &&
320+         (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
320321    const  bool  can_use_vector_kernel = Q->ne [0 ] <= 256  && Q->ne [0 ] % (2 *warp_size) == 0 ;
321322    if  (Q->ne [1 ] == 1  && can_use_vector_kernel && !mma_faster_for_bs1) {
322323        if  (prec == GGML_PREC_DEFAULT) {
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments