@@ -190,7 +190,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
-    on_no_fattn_vec_case(Q->ne[0]);
+    GGML_ABORT("fatal error");
 }
 
 #define FATTN_VEC_F32_CASE(D, type_K, type_V) \
@@ -265,74 +265,184 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
-    on_no_fattn_vec_case(Q->ne[0]);
+    GGML_ABORT("fatal error");
 }
 
-void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+// Best FlashAttention kernel for a specific GPU:
+enum best_fattn_kernel {
+    BEST_FATTN_KERNEL_NONE     =   0,
+    BEST_FATTN_KERNEL_TILE_F32 = 200,
+    BEST_FATTN_KERNEL_TILE_F16 = 210,
+    BEST_FATTN_KERNEL_VEC_F32  = 100,
+    BEST_FATTN_KERNEL_VEC_F16  = 110,
+    BEST_FATTN_KERNEL_WMMA_F16 = 300,
+    BEST_FATTN_KERNEL_MMA_F16  = 400,
+};
+
+static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
+#ifndef FLASH_ATTN_AVAILABLE
+    GGML_UNUSED(device); GGML_UNUSED(dst);
+    return BEST_FATTN_KERNEL_NONE;
+#endif // FLASH_ATTN_AVAILABLE
+
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
     const ggml_tensor * V    = dst->src[2];
     const ggml_tensor * mask = dst->src[3];
 
-    ggml_cuda_set_device(ctx.device);
-    const int cc        = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+
+    const int cc        = ggml_cuda_info().devices[device].cc;
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-#if defined(GGML_HIP_ROCWMMA_FATTN)
-    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
-        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-        return;
+    switch (K->ne[0]) {
+        case  64:
+        case 128:
+        case 256:
+            if (V->ne[0] != K->ne[0]) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        case  80:
+        case  96:
+        case 112:
+            if (V->ne[0] != K->ne[0]) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        case 576:
+            if (V->ne[0] != 512) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
     }
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
 
-    if (!fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
-        }
-        return;
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+    if (K->type != V->type) {
+        return BEST_FATTN_KERNEL_NONE;
     }
+#endif // GGML_CUDA_FA_ALL_QUANTS
 
-    if (!fp16_mma_available(cc)) {
-        if (prec == GGML_PREC_DEFAULT) {
-            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-            } else {
-                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+    switch (K->type) {
+        case GGML_TYPE_F16:
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+            return BEST_FATTN_KERNEL_NONE;
+#endif // GGML_CUDA_FA_ALL_QUANTS
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+            if (K->ne[0] != 128 && K->ne[0] != 64) {
+                return BEST_FATTN_KERNEL_NONE;
             }
-        } else {
-            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-            } else {
-                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+#else
+            if (K->ne[0] != 128) {
+                return BEST_FATTN_KERNEL_NONE;
             }
-        }
-        return;
+#endif // GGML_CUDA_FA_ALL_QUANTS
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+    switch (V->type) {
+        case GGML_TYPE_F16:
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+            if (K->ne[0] != 128) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+    if (mask && mask->ne[2] != 1) {
+        return BEST_FATTN_KERNEL_NONE;
     }
 
-    const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
-    const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
-    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
-        (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
-    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
-        if (prec == GGML_PREC_DEFAULT) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+
+    // If Turing tensor cores available, use them except for some cases with batch size 1:
+    if (turing_mma_available(cc)) {
+        const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask; // The mma-based kernels have GQA-specific optimizations
+        const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
+        const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (gqa_ratio > 4 && K->ne[1] >= 8192);
+        const bool mma_faster_for_bs1 = gqa_opt_applies && !mma_needs_data_conversion &&
+            (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
+        if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
+            if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+                return BEST_FATTN_KERNEL_VEC_F16;
+            }
+            return BEST_FATTN_KERNEL_VEC_F32;
         }
-        return;
+        return BEST_FATTN_KERNEL_MMA_F16;
     }
 
-    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
-        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-        return;
+    // Use kernels specialized for small batch sizes if possible:
+    if (Q->ne[1] <= 8 && can_use_vector_kernel) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            return BEST_FATTN_KERNEL_VEC_F16;
+        }
+        return BEST_FATTN_KERNEL_VEC_F32;
+    }
+
+    // For large batch sizes, use the WMMA kernel if possible:
+    if (fp16_mma_available(cc)) {
+        return BEST_FATTN_KERNEL_WMMA_F16;
+    }
+
+    // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
+    if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        return BEST_FATTN_KERNEL_TILE_F16;
     }
+    return BEST_FATTN_KERNEL_TILE_F32;
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_set_device(ctx.device);
+    switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
+        case BEST_FATTN_KERNEL_NONE:
+            GGML_ABORT("fatal error");
+        case BEST_FATTN_KERNEL_TILE_F32:
+            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_TILE_F16:
+            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_VEC_F32:
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_VEC_F16:
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_WMMA_F16:
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_MMA_F16:
+            ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+            break;
+    }
+}
 
-    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
+    return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
 }
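
As a usage sketch (not part of the diff above): the new ggml_cuda_flash_attn_ext_supported() helper lets backend code ask up front whether a FlashAttention op can be dispatched at all, instead of finding out via GGML_ABORT. The wrapper below is a hypothetical illustration assuming the ggml CUDA backend headers are in scope; the function name example_backend_supports_fattn and the surrounding policy are assumptions, only the call to ggml_cuda_flash_attn_ext_supported() comes from this change.

// Hypothetical sketch: gate GGML_OP_FLASH_ATTN_EXT support on the same kernel
// selection logic the dispatcher uses, so "supported" and "dispatchable" cannot disagree.
static bool example_backend_supports_fattn(int device, const ggml_tensor * op) {
    if (op->op != GGML_OP_FLASH_ATTN_EXT) {
        return false; // this sketch only covers FlashAttention
    }
    // Returns false exactly when ggml_cuda_get_best_fattn_kernel() would pick
    // BEST_FATTN_KERNEL_NONE, i.e. when ggml_cuda_flash_attn_ext() would abort.
    return ggml_cuda_flash_attn_ext_supported(device, op);
}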