@@ -406,6 +406,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
406406#endif // CP_ASYNC_AVAILABLE
407407
408408#else
409+ GGML_UNUSED (Q_f2); GGML_UNUSED (K_h2); GGML_UNUSED (V_h2);
410+ GGML_UNUSED (mask_h2); GGML_UNUSED (dstk); GGML_UNUSED (dstk_fixup);
411+ GGML_UNUSED (scale); GGML_UNUSED (slope); GGML_UNUSED (logit_softcap);
412+ GGML_UNUSED (ne01); GGML_UNUSED (ne02); GGML_UNUSED (stride_KV);
413+ GGML_UNUSED (stride_mask); GGML_UNUSED (jt); GGML_UNUSED (tile_K);
414+ GGML_UNUSED (stride_mask); GGML_UNUSED (jt); GGML_UNUSED (tile_K);
415+ GGML_UNUSED (tile_V); GGML_UNUSED (tile_mask); GGML_UNUSED (Q_B);
416+ GGML_UNUSED (VKQ_C); GGML_UNUSED (KQ_max); GGML_UNUSED (KQ_rowsum);
417+ GGML_UNUSED (kb0);
409418 NO_DEVICE_CODE;
410419#endif // NEW_MMA_AVAILABLE
411420}
@@ -797,6 +806,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
797806 __syncthreads ();
798807 }
799808#else
809+ GGML_UNUSED (Q_f2); GGML_UNUSED (K_h2); GGML_UNUSED (V_h2);
810+ GGML_UNUSED (mask_h2); GGML_UNUSED (dstk); GGML_UNUSED (dstk_fixup);
811+ GGML_UNUSED (scale); GGML_UNUSED (slope); GGML_UNUSED (logit_softcap);
812+ GGML_UNUSED (ne01); GGML_UNUSED (ne02); GGML_UNUSED (stride_Q1);
813+ GGML_UNUSED (stride_Q2); GGML_UNUSED (stride_KV); GGML_UNUSED (stride_mask);
814+ GGML_UNUSED (jt); GGML_UNUSED (kb0_start); GGML_UNUSED (kb0_stop);
800815 NO_DEVICE_CODE;
801816#endif // NEW_MMA_AVAILABLE
802817}
@@ -931,6 +946,16 @@ static __global__ void flash_attn_ext_f16(
931946 (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
932947 ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
933948#else
949+ GGML_UNUSED (Q); GGML_UNUSED (K); GGML_UNUSED (V); GGML_UNUSED (mask);
950+ GGML_UNUSED (dst); GGML_UNUSED (dst_meta); GGML_UNUSED (scale);
951+ GGML_UNUSED (max_bias); GGML_UNUSED (m0); GGML_UNUSED (m1);
952+ GGML_UNUSED (n_head_log2); GGML_UNUSED (logit_softcap); GGML_UNUSED (ne00);
953+ GGML_UNUSED (ne01); GGML_UNUSED (ne02); GGML_UNUSED (ne03); GGML_UNUSED (ne10);
954+ GGML_UNUSED (ne11); GGML_UNUSED (ne12); GGML_UNUSED (ne13); GGML_UNUSED (ne31);
955+ GGML_UNUSED (nb31); GGML_UNUSED (nb01); GGML_UNUSED (nb02); GGML_UNUSED (nb03);
956+ GGML_UNUSED (nb11); GGML_UNUSED (nb12); GGML_UNUSED (nb13); GGML_UNUSED (nb21);
957+ GGML_UNUSED (nb22); GGML_UNUSED (nb23); GGML_UNUSED (ne0); GGML_UNUSED (ne1);
958+ GGML_UNUSED (ne2); GGML_UNUSED (ne3);
934959 NO_DEVICE_CODE;
935960#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
936961}
@@ -985,38 +1010,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
9851010 extern DECL_FATTN_MMA_F16_CASE (D, (ncols)/4, 4); \
9861011 extern DECL_FATTN_MMA_F16_CASE (D, (ncols)/8, 8); \
9871012
988- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 64 , 8 );
989- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 80 , 8 );
990- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 96 , 8 );
991- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (112 , 8 );
992- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (128 , 8 );
993- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (256 , 8 );
994-
995- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 64 , 16 );
996- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 80 , 16 );
997- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 96 , 16 );
998- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (112 , 16 );
999- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (128 , 16 );
1000- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (256 , 16 );
1001-
1002- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 64 , 32 );
1003- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 80 , 32 );
1004- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 96 , 32 );
1005- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (112 , 32 );
1006- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (128 , 32 );
1007- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (256 , 32 );
1008-
1009- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 64 , 64 );
1010- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 80 , 64 );
1011- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 96 , 64 );
1012- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (112 , 64 );
1013- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (128 , 64 );
1014- DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 (256 , 64 );
1013+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2 ( 64 , 8 )
1014+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80 , 8 )
1015+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96 , 8 )
1016+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112 , 8 )
1017+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128 , 8 )
1018+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256 , 8 )
1019+
1020+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64 , 16 )
1021+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80 , 16 )
1022+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96 , 16 )
1023+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112 , 16 )
1024+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128 , 16 )
1025+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256 , 16 )
1026+
1027+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64 , 32 )
1028+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80 , 32 )
1029+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96 , 32 )
1030+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112 , 32 )
1031+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128 , 32 )
1032+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256 , 32 )
1033+
1034+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64 , 64 )
1035+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80 , 64 )
1036+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96 , 64 )
1037+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112 , 64 )
1038+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128 , 64 )
1039+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256 , 64 )
10151040
10161041// Kernels with ncols == 128 are only 4% faster due to register pressure.
1017- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
1018- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
1019- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
1020- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
1021- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
1022- // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
1042+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128)
1043+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128)
1044+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128)
1045+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128)
1046+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
1047+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
0 commit comments