@@ -625,9 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     int cols_per_block;
-    if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
-        cols_per_block = 64;
-    } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
+    if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
         cols_per_block = 32;
     } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
         cols_per_block = 16;
@@ -645,7 +643,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(64, 8, nwarps);
             FATTN_SWITCH_CASE(64, 16, nwarps);
             FATTN_SWITCH_CASE(64, 32, nwarps);
-            FATTN_SWITCH_CASE(64, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -655,7 +652,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             // FATTN_SWITCH_CASE(80, 8, nwarps);
             FATTN_SWITCH_CASE(80, 16, nwarps);
             FATTN_SWITCH_CASE(80, 32, nwarps);
-            // FATTN_SWITCH_CASE(80, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -665,7 +661,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(96, 8, nwarps);
             FATTN_SWITCH_CASE(96, 16, nwarps);
             FATTN_SWITCH_CASE(96, 32, nwarps);
-            FATTN_SWITCH_CASE(96, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -675,7 +670,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             // FATTN_SWITCH_CASE(112, 8, nwarps);
             FATTN_SWITCH_CASE(112, 16, nwarps);
             FATTN_SWITCH_CASE(112, 32, nwarps);
-            // FATTN_SWITCH_CASE(112, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -685,7 +679,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(128, 8, nwarps);
             FATTN_SWITCH_CASE(128, 16, nwarps);
             FATTN_SWITCH_CASE(128, 32, nwarps);
-            // FATTN_SWITCH_CASE(128, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -695,7 +688,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(256, 8, nwarps);
             FATTN_SWITCH_CASE(256, 16, nwarps);
             FATTN_SWITCH_CASE(256, 32, nwarps);
-            // FATTN_SWITCH_CASE(256, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
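Taken together, these hunks leave the kernel dispatch with cols_per_block values of 8, 16, and 32 only. A minimal standalone sketch of how the selection logic reads after the change; the final else branch (falling back to 8 columns) and the CC_AMPERE constant come from surrounding code not shown in this diff and are assumptions here:

    // Sketch of the column-block selection once the 64-column path is gone
    // (not the exact file contents; the else branch below is assumed).
    int cols_per_block;
    if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
        cols_per_block = 32; // large batches with head size <= 128, or Ampere-class hardware
    } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
        cols_per_block = 16; // medium batches, or head sizes not divisible by 32
    } else {
        cols_per_block = 8;  // assumed fallback for small batches (not visible in the hunks above)
    }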