File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -315,8 +315,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
315
315
316
316
const bool gqa_opt_applies = ((Q->ne [2 ] / K->ne [2 ]) % 2 == 0 ) && mask; // The mma-based kernels have GQA-specific optimizations
317
317
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
318
- const bool mma_faster_for_bs1 = new_mma_available (cc) && gqa_opt_applies &&
319
- (Q->ne [3 ] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
318
+ const bool mma_faster_for_rtx4000 = Q->ne [3 ] > 1 || (Q->ne [2 ] > 4 *K->ne [2 ] && K->ne [1 ] >= 8192 );
319
+ const bool mma_faster_for_bs1 = new_mma_available (cc) && gqa_opt_applies && !mma_needs_data_conversion &&
320
+ (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
320
321
const bool can_use_vector_kernel = Q->ne [0 ] <= 256 && Q->ne [0 ] % (2 *warp_size) == 0 ;
321
322
if (Q->ne [1 ] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
322
323
if (prec == GGML_PREC_DEFAULT) {
You can’t perform that action at this time.
0 commit comments