File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -446,8 +446,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
446446
447447 const bool gqa_opt_applies = ((Q->ne [2 ] / K->ne [2 ]) % 2 == 0 ) && mask; // The mma-based kernels have GQA-specific optimizations
448448 const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
449- const bool mma_faster_for_bs1 = new_mma_available (cc) && gqa_opt_applies &&
450- (Q->ne [3 ] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
449+ const bool mma_faster_for_rtx4000 = Q->ne [3 ] > 1 || (Q->ne [2 ] > 4 *K->ne [2 ] && K->ne [1 ] >= 8192 );
450+ const bool mma_faster_for_bs1 = new_mma_available (cc) && gqa_opt_applies && !mma_needs_data_conversion &&
451+ (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
451452 const bool can_use_vector_kernel = Q->ne [0 ] <= 256 && Q->ne [0 ] % (2 *warp_size) == 0 ;
452453 if (Q->ne [1 ] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
453454 if (prec == GGML_PREC_DEFAULT) {
You can’t perform that action at this time.
0 commit comments