Commit cae2b81

ikawrakow and Iwan Kawrakow authored
FA: Add option to build all FA kernels (#197)
Similar to the CUDA situation, the option is OFF by default. When it is OFF, only the F16, Q8_0, Q6_0 and, if the CPU provides native BF16 support, BF16 FA kernels are included. To enable all of them, configure with cmake -DGGML_IQK_FA_ALL_QUANTS=1 ...

This cuts the compilation time of iqk_mul_mat.cpp by almost half (45 seconds vs 81 seconds on my Ryzen-7950X).

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent: 33390c4

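Editorial note: the compile-time savings come from pruning template instantiations. Each if-constexpr dispatch arm that names a kernel forces the compiler to instantiate that heavily templated kernel, so compiling an arm out behind a preprocessor guard removes all of its instantiations. Below is a minimal self-contained sketch of the pattern; the tag types, fa_kernel and dispatch are hypothetical stand-ins, not the actual iqk_mul_mat.cpp code.

    // Sketch of compile-time kernel gating (hypothetical names throughout).
    #include <cassert>
    #include <type_traits>

    struct F16 {}; struct Q80 {}; struct Q40 {};

    // Stand-in for an expensive-to-instantiate FA kernel template.
    template <typename KV, int D> void fa_kernel() { /* heavy template code */ }

    template <typename KV, int D>
    void dispatch() {
        if constexpr (std::is_same_v<KV, F16>)      fa_kernel<F16, D>();
        else if constexpr (std::is_same_v<KV, Q80>) fa_kernel<Q80, D>();
    #if GGML_IQK_FA_ALL_QUANTS  // an undefined macro evaluates to 0 in #if
        else if constexpr (std::is_same_v<KV, Q40>) fa_kernel<Q40, D>();
    #endif
        else assert(false && "kernel not compiled in"); // mirrors GGML_ASSERT(false)
    }

    int main() { dispatch<F16, 128>(); dispatch<Q80, 128>(); }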
3 files changed: 39 additions & 33 deletions

ggml/CMakeLists.txt: 2 additions & 0 deletions

@@ -130,6 +130,8 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM"
 option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
 option(GGML_CUDA_USE_GRAPHS    "ggml: use CUDA graphs (llama.cpp only)"      OFF)
 
+option(GGML_IQK_FA_ALL_QUANTS  "ggml: compile all quants for IQK FlashAttention" OFF)
+
 option(GGML_CURL    "ggml: use libcurl to download model from an URL" OFF)
 option(GGML_HIPBLAS "ggml: use hipBLAS" OFF)
 option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)

ggml/src/CMakeLists.txt: 4 additions & 0 deletions

@@ -259,6 +259,10 @@ if (GGML_IQK_MUL_MAT)
     add_compile_definitions(GGML_USE_IQK_MULMAT)
     set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp)
     set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h)
+    if (GGML_IQK_FA_ALL_QUANTS)
+        message(STATUS "Including all IQK FA kernels")
+        add_compile_definitions(GGML_IQK_FA_ALL_QUANTS)
+    endif()
 endif()
 
 if (GGML_LLAMAFILE)
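This hunk is what connects the CMake option to the source changes below: when GGML_IQK_FA_ALL_QUANTS is ON, add_compile_definitions() passes -DGGML_IQK_FA_ALL_QUANTS (defined as 1) to every compilation in the directory; when it is OFF, the macro is simply left undefined, which the #if GGML_IQK_FA_ALL_QUANTS guards evaluate as 0. A small illustrative program, not part of the commit, showing how a translation unit sees the two configurations:

    // Illustrative only: reports which IQK FA kernel set was compiled in.
    // Build with -DGGML_IQK_FA_ALL_QUANTS to flip the branch.
    #include <cstdio>

    int main() {
    #if GGML_IQK_FA_ALL_QUANTS
        std::puts("IQK FA kernels: all supported quants");
    #else
        std::puts("IQK FA kernels: F16, Q8_0, Q6_0 (+BF16 where the CPU has native support)");
    #endif
    }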

ggml/src/iqk/iqk_mul_mat.cpp: 33 additions & 33 deletions

@@ -15239,14 +15239,7 @@ struct FlashQKfp32 {
                 case 7: return std::make_pair(mul_mat<7>, 7);\
             }\
         }
-        if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
-#ifdef __aarch64__
-            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
-#else
-            MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq);
-#endif
-        }
-        else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
+        if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
 #ifdef __aarch64__
             MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
 #else

@@ -15262,6 +15255,21 @@ struct FlashQKfp32 {
             MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
 #else
             MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq);
+#endif
+        }
+        else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
+#ifdef __aarch64__
+            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
+#else
+            MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq);
+#endif
+        }
+#if GGML_IQK_FA_ALL_QUANTS
+        else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
+#ifdef __aarch64__
+            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
+#else
+            MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq);
 #endif
         }
         else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {

@@ -15278,13 +15286,7 @@ struct FlashQKfp32 {
             MAKE_FUNCS(mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, nq);
 #endif
         }
-        else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
-#ifdef __aarch64__
-            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
-#else
-            MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq);
 #endif
-        }
         else {
             GGML_ASSERT(false);
         }
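A subtle point in the hunk above: the deleted Q6_0 branch leaves one #endif behind as a context line. That #endif no longer closes the branch's #ifdef __aarch64__; it now closes the #if GGML_IQK_FA_ALL_QUANTS opened before the HelperQ40 branch, so the Q4_0, Q4_1 and IQ4_NL dispatch arms are all compiled out in the default configuration.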
@@ -15493,17 +15495,6 @@ struct FlashAttn {
     template <typename KHelper, typename VHelper>
     void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
             const float * q, const char * mask, float * qkv) {
-        // if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
-        //     std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
-        //     std::is_same_v<KHelper, HelperQ80<D, k_step>> ||
-        //     std::is_same_v<KHelper, HelperQ80R4<D, k_step>> ||
-        //     std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
-        //     compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
-        //             kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
-        // } else {
-        //     compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
-        //             kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
-        // }
         if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
             std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
             std::is_same_v<KHelper, HelperQ60<D, k_step>>) {

@@ -16027,6 +16018,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
             HelperQ80<D, k_step> vh(v, stride_v);
             iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
+        case GGML_TYPE_Q6_0: {
+            HelperQ60<D, k_step> vh(v, stride_v);
+            iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+        } break;
+#if GGML_IQK_FA_ALL_QUANTS
         case GGML_TYPE_Q4_0: {
             HelperQ40<D, k_step> vh(v, stride_v);
             iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);

@@ -16039,10 +16035,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
             HelperIQ4nl<D, k_step> vh(v, stride_v);
             iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
-        case GGML_TYPE_Q6_0: {
-            HelperQ60<D, k_step> vh(v, stride_v);
-            iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
-        } break;
+#endif
         default: break;
     }
 }

@@ -16062,6 +16055,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
             HelperQ80<D, k_step> kh(k, stride_k);
             iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
+        case GGML_TYPE_Q6_0: {
+            HelperQ60<D, k_step> kh(k, stride_k);
+            iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+        } break;
+#if GGML_IQK_FA_ALL_QUANTS
         case GGML_TYPE_Q4_0: {
             HelperQ40<D, k_step> kh(k, stride_k);
             iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);

@@ -16074,10 +16072,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
             HelperIQ4nl<D, k_step> kh(k, stride_k);
             iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
-        case GGML_TYPE_Q6_0: {
-            HelperQ60<D, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
-        } break;
+#endif
         default: break;
     }
 

@@ -16087,8 +16082,12 @@ inline bool flash_attn_is_supported(ggml_type type) {
 #ifdef __AVX512BF16__
     if (type == GGML_TYPE_BF16) return true;
 #endif
+#if GGML_IQK_FA_ALL_QUANTS
     if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
         type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
+#else
+    if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true;
+#endif
     return false;
 }
 }

@@ -16115,6 +16114,7 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
     auto type_v = ggml_type(int_type_v);
     if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
     if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
+    if (D != 64 && D != 96 && D != 128 && D != 256) return false;
 
     auto ck = (const char *)k;
     auto cv = (const char *)v;
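The final hunk adds an early-out on the head size D: only 64, 96, 128 and 256 are accepted, presumably because those are the only head sizes the templates are instantiated for, and a false return tells the caller to fall back to a non-IQK attention path. A standalone restatement of the predicate follows; the helper name is hypothetical, not the actual function:

    // Sketch of the head-size gate added to iqk_flash_attn_noalibi().
    #include <cstdio>

    static bool iqk_fa_head_size_supported(int D) { // hypothetical helper
        return D == 64 || D == 96 || D == 128 || D == 256;
    }

    int main() {
        for (int D : {64, 80, 96, 128, 192, 256}) {
            std::printf("D = %3d -> %s\n", D,
                        iqk_fa_head_size_supported(D) ? "IQK FA path" : "fall back");
        }
    }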
