From bbcf67b94bd47d47bc88d12599ae25cc213b3df2 Mon Sep 17 00:00:00 2001
From: ZelinMa557 <3388706467@qq.com>
Date: Mon, 5 May 2025 23:23:35 +0800
Subject: [PATCH] [Perf] [CPU] eliminate redundant memory access in group
 query attention

Signed-off-by: ZelinMa557 <3388706467@qq.com>
---
 ggml/src/ggml-cpu/ggml-cpu.c | 241 ++++++++++++++++++++---------------
 1 file changed, 140 insertions(+), 101 deletions(-)

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f2ab4c5d69582..bab8ecc0715be 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -12152,7 +12152,77 @@ static void ggml_compute_forward_argsort(
 }
 
 // ggml_compute_forward_flash_attn_ext
+static inline void ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+        const ggml_fp16_t *Q,
+        const char *K,
+        const char *V,
+        const int64_t D,
+        const float mask_value,
+        const float scale,
+        const float logit_softcap,
+        const enum ggml_type v_type,
+        ggml_vec_dot_t const kq_vec_dot,
+        ggml_to_float_t const v_to_float,
+        ggml_fp16_t *VKQ16,
+        float *VKQ32,
+        float *V32,
+        float *sum,
+        float *max_kq_value) {
+    float s; // KQ value
+    kq_vec_dot(D, &s, 0, K, 0, Q, 0, 1);
+
+    s = s*scale; // scale KQ value
+
+    if (logit_softcap != 0.0f) {
+        s = logit_softcap*tanhf(s);
+    }
+    s += mask_value; // apply mask
+    float M = *max_kq_value;
+    const float Mold = M;
+
+    float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+    float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+    if (v_type == GGML_TYPE_F16) {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
+
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f16(D, VKQ16, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
+
+        // V += v*expf(s - M)
+        ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) V, vs);
+    } else {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
+
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f32(D, VKQ32, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
+        v_to_float(V, V32, D);
+
+        // V += v*expf(s - M)
+        ggml_vec_mad_f32(D, VKQ32, V32, vs);
+    }
+    float S = *sum;
+    S = S*ms + vs; // scale and increment sum with partial sum
+    *sum = S;
+    *max_kq_value = M;
+}
 
+#define GGML_FLASH_ATTN_EXT_MAX_GQA 16
+
 static void ggml_compute_forward_flash_attn_ext_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * q,
@@ -12179,6 +12249,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne0 == D);
     GGML_ASSERT(ne2 == N);
 
+    const int n_gqa = neq2 / nek2;
+    GGML_ASSERT(n_gqa <= GGML_FLASH_ATTN_EXT_MAX_GQA);
     // input tensor rows must be contiguous
     GGML_ASSERT(nbq0 == ggml_type_size(q->type));
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
@@ -12206,15 +12278,15 @@
 
     // parallelize by q rows using ggml_vec_dot_f32
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
+    // total groups in q
+    const int ng = neq1*neq2*neq3/n_gqa;
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // groups per thread
+    const int dg = (ng + nth - 1)/nth;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    // group range for this thread
+    const int ig0 = dg*ith;
+    const int ig1 = MIN(ig0 + dg, ng);
 
     float scale = 1.0f;
     float max_bias = 0.0f;
@@ -12242,28 +12314,42 @@
     GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
     GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
 
-    // loop over n_batch and n_head
-    for (int ir = ir0; ir < ir1; ++ir) {
+    float S[GGML_FLASH_ATTN_EXT_MAX_GQA]; // sum
+    float M[GGML_FLASH_ATTN_EXT_MAX_GQA]; // maximum KQ value
+    float * VKQ32[GGML_FLASH_ATTN_EXT_MAX_GQA]; // FP32 VKQ accumulator
+    float * V32[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) FP32 V buffer
+    ggml_fp16_t * VKQ16[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) FP16 VKQ accumulator
+    ggml_fp16_t * Q_q[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) buffer for Q converted to quantized/FP16
+    float slope[GGML_FLASH_ATTN_EXT_MAX_GQA];
+
+    // loop over n_batch and n_group
+    for (int ig = ig0; ig < ig1; ++ig) {
+        const int group_index = ig % (neq2 / n_gqa); // kv-head group within the current q row
+        const int batch_index = ig / (neq2 / n_gqa); // q row this group belongs to
         // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+        const int iq3 = 0;
+        const int iq2 = group_index * n_gqa; // start head index
+        const int iq1 = batch_index;
+
+        for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+            S[i_gqa] = 0.0f;
+            M[i_gqa] = -INFINITY;
+            VKQ32 [i_gqa] = (float *) params->wdata + ith*(3*n_gqa*D + CACHE_LINE_SIZE_F32) + 3*i_gqa*D;
+            V32   [i_gqa] = (VKQ32[i_gqa] + 1*D);
+            VKQ16 [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 1*D);
+            Q_q   [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 2*D);
 
-        const uint32_t h = iq2; // head index
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float S = 0.0f; // sum
-        float M = -INFINITY; // maximum KQ value
-
-        float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+            if (v->type == GGML_TYPE_F16) {
+                memset(VKQ16[i_gqa], 0, 1*D*sizeof(ggml_fp16_t));
+            } else {
+                memset(VKQ32[i_gqa], 0, 1*D*sizeof(float));
+            }
 
-        if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
-        } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + (iq2 + i_gqa)*nbq2 + iq3*nbq3));
+            q_to_vec_dot(pq, Q_q[i_gqa], D);
+
+            const uint32_t h = iq2 + i_gqa;
+            slope[i_gqa] = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -12276,95 +12362,46 @@
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
-        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
-
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
-            if (mv == -INFINITY) {
+            const float mp_value_base = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mp_value_base == -INFINITY) {
                 continue;
             }
-
-            float s; // KQ value
-
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
-
-            s = s*scale; // scale KQ value
-
-            if (logit_softcap != 0.0f) {
-                s = logit_softcap*tanhf(s);
+            for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+                const float mv = mp_value_base * slope[i_gqa];
+                ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+                    Q_q[i_gqa], k_data, v_data, D, mv, scale, logit_softcap, v->type,
+                    kq_vec_dot, v_to_float, VKQ16[i_gqa], VKQ32[i_gqa], V32[i_gqa], S+i_gqa, M+i_gqa);
             }
+        }
 
-            s += mv; // apply mask
-
-            const float Mold = M;
-
-            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
-            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
-
-            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
-
+        for (int i = 0; i < n_gqa; ++i) {
             if (v->type == GGML_TYPE_F16) {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
-            } else {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
+                for (int64_t d = 0; d < D; ++d) {
+                    VKQ32[i][d] = GGML_FP16_TO_FP32(VKQ16[i][d]);
                 }
-
-                v_to_float(v_data, V32, D);
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
             }
 
-            S = S*ms + vs; // scale and increment sum with partial sum
-        }
+            // V /= S
+            const float S_inv = 1.0f/S[i];
+            ggml_vec_scale_f32(D, VKQ32[i], S_inv);
 
-        if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
-                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
-            }
-        }
-
-        // V /= S
-        const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
-
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2 + i;
+            const int i3 = iq3;
 
-        // original
-        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+            // original
+            //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32[i], nb1);
+        }
     }
 }
 
@@ -15132,8 +15169,10 @@ struct ggml_cplan ggml_graph_plan(
             case GGML_OP_FLASH_ATTN_EXT:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D
-
-                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                    const int64_t ne02 = node->src[0]->ne[2]; // n_head
+                    const int64_t ne12 = node->src[1]->ne[2]; // n_head_kv
+                    const int64_t n_gqa = ne02/ne12;
+                    cur = 3*sizeof(float)*ne00*n_tasks*n_gqa; // 3x head size/thread
                 } break;
            case GGML_OP_FLASH_ATTN_BACK:
                 {
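
Illustrative note (not part of the commit): the new ggml_compute_forward_flash_attn_ext_f16_one_QKV helper folds one K/V row into a head's accumulator with the online-softmax update from the paper cited in the code (https://arxiv.org/pdf/2112.05682.pdf), and the patch calls it n_gqa times per K/V row so that the row is loaded from memory once and reused by every query head that shares it. The standalone C sketch below shows only that update, with plain arrays instead of ggml vectors; the function name, the head size D and the sample scores are made up for illustration and are not ggml APIs.

// Standalone sketch of the running-max/rescale update applied once per (score, V row).
// acc[D] accumulates sum_i expf(s_i - M) * v_i, *S accumulates sum_i expf(s_i - M),
// *M is the running maximum score seen so far.
#include <math.h>
#include <stdio.h>

#define D 4 // head size in this toy example

static void online_softmax_step(const float s, const float *v, float *acc, float *S, float *M) {
    float ms = 1.0f; // rescales the old accumulator when a new maximum appears
    float vs = 1.0f; // weight of the incoming value row, expf(s - M)

    if (s > *M) {
        ms = expf(*M - s); // shrink previous contributions to the new scale
        *M = s;
        for (int d = 0; d < D; ++d) {
            acc[d] *= ms;
        }
    } else {
        vs = expf(s - *M);
    }

    for (int d = 0; d < D; ++d) {
        acc[d] += vs * v[d];
    }
    *S = *S * ms + vs; // same "S = S*ms + vs" update as in the patch
}

int main(void) {
    const float scores[3]    = { 0.5f, 2.0f, -1.0f };
    const float values[3][D] = { { 1, 0, 0, 0 }, { 0, 1, 0, 0 }, { 0, 0, 1, 0 } };

    float acc[D] = { 0 };
    float S = 0.0f;
    float M = -INFINITY;

    for (int i = 0; i < 3; ++i) {
        online_softmax_step(scores[i], values[i], acc, &S, &M);
    }

    // Final normalization, done once per head after the KV loop.
    for (int d = 0; d < D; ++d) {
        printf("%f ", acc[d] / S);
    }
    printf("\n");
    return 0;
}

Because the update only needs the per-head triple (acc, S, M), the kernel can keep n_gqa such accumulators live in the per-thread work buffer and stream each K/V row through all of them, which is the memory-access saving the commit title refers to.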