@@ -15239,14 +15239,7 @@ struct FlashQKfp32 {
1523915239 case 7: return std::make_pair(mul_mat<7>, 7);\
1524015240 }\
1524115241 }
15242- if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
15243- #ifdef __aarch64__
15244- MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
15245- #else
15246- MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq);
15247- #endif
15248- }
15249- else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
15242+ if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
1525015243#ifdef __aarch64__
1525115244 MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
1525215245#else
@@ -15262,6 +15255,21 @@ struct FlashQKfp32 {
1526215255 MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
1526315256#else
1526415257 MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq);
15258+ #endif
15259+ }
15260+ else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
15261+ #ifdef __aarch64__
15262+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
15263+ #else
15264+ MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq);
15265+ #endif
15266+ }
15267+ #if GGML_IQK_FA_ALL_QUANTS
15268+ else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
15269+ #ifdef __aarch64__
15270+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
15271+ #else
15272+ MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq);
1526515273#endif
1526615274 }
1526715275 else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
@@ -15278,13 +15286,7 @@ struct FlashQKfp32 {
1527815286 MAKE_FUNCS(mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, nq);
1527915287#endif
1528015288 }
15281- else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
15282- #ifdef __aarch64__
15283- MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
15284- #else
15285- MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq);
1528615289#endif
15287- }
1528815290 else {
1528915291 GGML_ASSERT(false);
1529015292 }
@@ -15493,17 +15495,6 @@ struct FlashAttn {
1549315495 template <typename KHelper, typename VHelper>
1549415496 void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
1549515497 const float * q, const char * mask, float * qkv) {
15496- // if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
15497- // std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
15498- // std::is_same_v<KHelper, HelperQ80<D, k_step>> ||
15499- // std::is_same_v<KHelper, HelperQ80R4<D, k_step>> ||
15500- // std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
15501- // compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
15502- // kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
15503- // } else {
15504- // compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
15505- // kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
15506- // }
1550715498 if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
1550815499 std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
1550915500 std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
@@ -16027,6 +16018,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
1602716018 HelperQ80<D, k_step> vh(v, stride_v);
1602816019 iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
1602916020 } break;
16021+ case GGML_TYPE_Q6_0: {
16022+ HelperQ60<D, k_step> vh(v, stride_v);
16023+ iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
16024+ } break;
16025+ #if GGML_IQK_FA_ALL_QUANTS
1603016026 case GGML_TYPE_Q4_0: {
1603116027 HelperQ40<D, k_step> vh(v, stride_v);
1603216028 iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -16039,10 +16035,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
1603916035 HelperIQ4nl<D, k_step> vh(v, stride_v);
1604016036 iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
1604116037 } break;
16042- case GGML_TYPE_Q6_0: {
16043- HelperQ60<D, k_step> vh(v, stride_v);
16044- iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
16045- } break;
16038+ #endif
1604616039 default: break;
1604716040 }
1604816041}
@@ -16062,6 +16055,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
1606216055 HelperQ80<D, k_step> kh(k, stride_k);
1606316056 iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
1606416057 } break;
16058+ case GGML_TYPE_Q6_0: {
16059+ HelperQ60<D, k_step> kh(k, stride_k);
16060+ iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
16061+ } break;
16062+ #if GGML_IQK_FA_ALL_QUANTS
1606516063 case GGML_TYPE_Q4_0: {
1606616064 HelperQ40<D, k_step> kh(k, stride_k);
1606716065 iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -16074,10 +16072,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
1607416072 HelperIQ4nl<D, k_step> kh(k, stride_k);
1607516073 iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
1607616074 } break;
16077- case GGML_TYPE_Q6_0: {
16078- HelperQ60<D, k_step> kh(k, stride_k);
16079- iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
16080- } break;
16075+ #endif
1608116076 default: break;
1608216077 }
1608316078
@@ -16087,8 +16082,12 @@ inline bool flash_attn_is_supported(ggml_type type) {
1608716082#ifdef __AVX512BF16__
1608816083 if (type == GGML_TYPE_BF16) return true;
1608916084#endif
16085+ #if GGML_IQK_FA_ALL_QUANTS
1609016086 if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
1609116087 type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
16088+ #else
16089+ if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true;
16090+ #endif
1609216091 return false;
1609316092}
1609416093}
@@ -16115,6 +16114,7 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
1611516114 auto type_v = ggml_type(int_type_v);
1611616115 if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
1611716116 if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
16117+ if (D != 64 && D != 96 && D != 128 && D != 256) return false;
1611816118
1611916119 auto ck = (const char *)k;
1612016120 auto cv = (const char *)v;
0 commit comments