Commit bbcf67b

[Perf] [CPU] eliminate redundant memory access in group query attention
Signed-off-by: ZelinMa557 <[email protected]>
1 parent f4c3dd5 commit bbcf67b

1 file changed: +140 / -101 lines
ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 140 additions & 101 deletions
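In short (as read from the diff below): the per-K/V-row online-softmax update is factored out into a new helper, ggml_compute_forward_flash_attn_ext_f16_one_QKV; the CPU flash-attention kernel is re-parallelized over groups of n_gqa = neq2/nek2 query heads that share a single K/V head (capped by the new GGML_FLASH_ATTN_EXT_MAX_GQA = 16); and the per-thread work buffer reserved in ggml_graph_plan grows by the same factor. As a result, each K and V row is read from memory once per group instead of once per query head.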
@@ -12152,7 +12152,77 @@ static void ggml_compute_forward_argsort(
 }
 
 // ggml_compute_forward_flash_attn_ext
+static inline void ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+        const ggml_fp16_t *Q,
+        const char *K,
+        const char *V,
+        const int64_t D,
+        const float mask_value,
+        const float scale,
+        const float logit_softcap,
+        const enum ggml_type v_type,
+        ggml_vec_dot_t const kq_vec_dot,
+        ggml_to_float_t const v_to_float,
+        ggml_fp16_t *VKQ16,
+        float *VKQ32,
+        float *V32,
+        float *sum,
+        float *max_kq_value) {
+    float s; // KQ value
+    kq_vec_dot(D, &s, 0, K, 0, Q, 0, 1);
+
+    s = s*scale; // scale KQ value
+
+    if (logit_softcap != 0.0f) {
+        s = logit_softcap*tanhf(s);
+    }
+    s += mask_value; // apply mask
+    float M = *max_kq_value;
+    const float Mold = M;
+
+    float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+    float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+    if (v_type == GGML_TYPE_F16) {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
+
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f16(D, VKQ16, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
+
+        // V += v*expf(s - M)
+        ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) V, vs);
+    } else {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
+
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f32(D, VKQ32, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
 
+        v_to_float(V, V32, D);
+
+        // V += v*expf(s - M)
+        ggml_vec_mad_f32(D, VKQ32, V32, vs);
+    }
+    float S = *sum;
+    S = S*ms + vs; // scale and increment sum with partial sum
+    *sum = S;
+    *max_kq_value = M;
+}
+
+#define GGML_FLASH_ATTN_EXT_MAX_GQA 16
 static void ggml_compute_forward_flash_attn_ext_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * q,
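The new helper above folds a single K/V row into the running online-softmax state of one query head: the running maximum M, the running sum S, and the V accumulator (VKQ16 or VKQ32). For reference, here is a minimal standalone illustration of the same streaming recurrence on scalar float data; the names softmax_state and online_softmax_step are invented for this sketch and are not part of ggml:

// Minimal illustration of the online-softmax update used by the helper above.
// All names here (softmax_state, online_softmax_step) are illustrative only.
#include <math.h>
#include <stdio.h>

typedef struct {
    float M;   // running maximum of the scores seen so far
    float S;   // running sum of exp(score - M)
    float acc; // running weighted sum of values, kept scaled relative to exp(M)
} softmax_state;

// Fold one (score, value) pair into the running state.
static void online_softmax_step(softmax_state *st, float score, float value) {
    const float M_new = score > st->M ? score : st->M;
    const float ms    = expf(st->M - M_new);  // rescale old accumulators
    const float vs    = expf(score - M_new);  // weight of the new value
    st->acc = st->acc*ms + value*vs;
    st->S   = st->S  *ms + vs;
    st->M   = M_new;
}

int main(void) {
    const float scores[4] = {0.5f, 2.0f, -1.0f, 3.0f};
    const float values[4] = {1.0f, 2.0f,  3.0f, 4.0f};

    softmax_state st = { -INFINITY, 0.0f, 0.0f };
    for (int i = 0; i < 4; ++i) {
        online_softmax_step(&st, scores[i], values[i]);
    }
    // st.acc/st.S equals sum_i softmax(scores)_i * values_i
    printf("streaming  output: %f\n", st.acc/st.S);

    // reference: conventional two-pass softmax
    float M = -INFINITY, S = 0.0f, out = 0.0f;
    for (int i = 0; i < 4; ++i) M = scores[i] > M ? scores[i] : M;
    for (int i = 0; i < 4; ++i) S += expf(scores[i] - M);
    for (int i = 0; i < 4; ++i) out += expf(scores[i] - M)/S * values[i];
    printf("reference  output: %f\n", out);
    return 0;
}

Dividing acc by S after the last step reproduces the two-pass softmax result, which is exactly what the kernel does once all K/V rows have been consumed (the "V /= S" step further down in this diff).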
@@ -12179,6 +12249,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne0 == D);
     GGML_ASSERT(ne2 == N);
 
+    const int n_gqa = neq2 / nek2;
+    GGML_ASSERT(n_gqa <= GGML_FLASH_ATTN_EXT_MAX_GQA);
     // input tensor rows must be contiguous
     GGML_ASSERT(nbq0 == ggml_type_size(q->type));
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
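n_gqa is the number of query heads that share one K/V head (grouped-query attention). A tiny worked example with made-up head counts:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    // hypothetical GQA configuration, not taken from the commit:
    const int64_t neq2 = 32; // number of query heads (q->ne[2])
    const int64_t nek2 = 8;  // number of K/V heads   (k->ne[2])
    const int n_gqa = (int)(neq2/nek2);
    // 4 query heads share each K/V head here; the new assert requires
    // n_gqa <= GGML_FLASH_ATTN_EXT_MAX_GQA (16)
    printf("n_gqa = %d\n", n_gqa);
    return 0;
}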
@@ -12206,15 +12278,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
     // parallelize by q rows using ggml_vec_dot_f32
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
+    // total groups in q
+    const int ng = neq1*neq2*neq3/n_gqa;
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // groups per thread
+    const int dg = (ng + nth - 1)/nth;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    // group range for this thread
+    const int ig0 = dg*ith;
+    const int ig1 = MIN(ig0 + dg, ng);
 
     float scale = 1.0f;
     float max_bias = 0.0f;
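The work unit handed to each thread changes from one Q row to one group of n_gqa Q rows that all attend to the same K/V head. A small standalone sketch of the partitioning arithmetic (illustrative values, plain C, no ggml types):

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int ng  = 56; // total groups, i.e. neq1*neq2*neq3/n_gqa (hypothetical)
    const int nth = 4;  // number of threads

    const int dg = (ng + nth - 1)/nth; // groups per thread, rounded up
    for (int ith = 0; ith < nth; ++ith) {
        const int ig0 = dg*ith;
        const int ig1 = MIN(ig0 + dg, ng);
        printf("thread %d handles groups [%d, %d)\n", ith, ig0, ig1);
    }
    return 0;
}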
@@ -12242,28 +12314,42 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
     GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
 
-    // loop over n_batch and n_head
-    for (int ir = ir0; ir < ir1; ++ir) {
+    float S[GGML_FLASH_ATTN_EXT_MAX_GQA];             // sum
+    float M[GGML_FLASH_ATTN_EXT_MAX_GQA];             // maximum KQ value
+    float * VKQ32[GGML_FLASH_ATTN_EXT_MAX_GQA];       // FP32 VKQ accumulator
+    float * V32[GGML_FLASH_ATTN_EXT_MAX_GQA];         // (temporary) FP32 V buffer
+    ggml_fp16_t * VKQ16[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) FP16 VKQ accumulator
+    ggml_fp16_t * Q_q[GGML_FLASH_ATTN_EXT_MAX_GQA];   // (temporary) buffer for Q converted to quantized/FP16
+    float slope[GGML_FLASH_ATTN_EXT_MAX_GQA];
+
+    // loop over n_batch and n_group
+    for (int ig = ig0; ig < ig1; ++ig) {
+        const int group_index = ig % ng;
+        const int batch_index = ig / ng;
         // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+        const int iq3 = 0;
+        const int iq2 = group_index * n_gqa; // start head index
+        const int iq1 = batch_index;
+
+        for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+            S[i_gqa] = 0.0f;
+            M[i_gqa] = -INFINITY;
+            VKQ32 [i_gqa] = (float *) params->wdata + ith*(3*n_gqa*D + CACHE_LINE_SIZE_F32) + 3*i_gqa*D;
+            V32   [i_gqa] = (VKQ32[i_gqa] + 1*D);
+            VKQ16 [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 1*D);
+            Q_q   [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 2*D);
 
-        const uint32_t h = iq2; // head index
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float S = 0.0f;      // sum
-        float M = -INFINITY; // maximum KQ value
-
-        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   = (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+            if (v->type == GGML_TYPE_F16) {
+                memset(VKQ16[i_gqa], 0, 1*D*sizeof(ggml_fp16_t));
+            } else {
+                memset(VKQ32[i_gqa], 0, 1*D*sizeof(float));
+            }
 
-        if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
-        } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + (iq2 + i_gqa)*nbq2 + iq3*nbq3));
+            q_to_vec_dot(pq, Q_q[i_gqa], D);
+
+            const uint32_t h = iq2 + i_gqa;
+            slope[i_gqa] = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
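Each thread's slice of params->wdata now has to hold 3*D floats of scratch per query head in the group, hence the ith*(3*n_gqa*D + CACHE_LINE_SIZE_F32) stride and the per-head offset 3*i_gqa*D. A rough sketch of that layout, with made-up sizes and a stand-in FP16 type:

#include <stdio.h>
#include <stdlib.h>

typedef unsigned short fp16_stub; // stand-in for ggml_fp16_t, illustration only

int main(void) {
    const int D     = 128; // head size (hypothetical)
    const int n_gqa = 4;   // query heads per K/V head (hypothetical)
    const int nth   = 2;   // number of threads (hypothetical)
    const int CACHE_LINE_SIZE_F32 = 64/sizeof(float);

    // per-thread stride in floats, matching the work-buffer math in the diff
    const size_t per_thread = 3*(size_t)n_gqa*D + CACHE_LINE_SIZE_F32;
    float * wdata = calloc(per_thread*nth, sizeof(float));

    const int ith = 1; // look at the second thread's slice
    for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
        float * VKQ32 = wdata + ith*per_thread + 3*(size_t)i_gqa*D;  // FP32 accumulator (D floats)
        float * V32   = VKQ32 + 1*D;                                 // temporary FP32 V row
        fp16_stub * VKQ16 = (fp16_stub *)(VKQ32 + 1*D);              // FP16 accumulator, shares V32's slot
        fp16_stub * Q_q   = (fp16_stub *)(VKQ32 + 2*D);              // converted Q row
        printf("head %d: VKQ32 at float offset %zu\n", i_gqa, (size_t)(VKQ32 - wdata));
        (void)V32; (void)VKQ16; (void)Q_q;
    }

    free(wdata);
    return 0;
}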
@@ -12276,95 +12362,46 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
-        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
-
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
-            if (mv == -INFINITY) {
+            const float mp_value_base = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mp_value_base == -INFINITY) {
                 continue;
            }
-
-            float s; // KQ value
-
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
-
-            s = s*scale; // scale KQ value
-
-            if (logit_softcap != 0.0f) {
-                s = logit_softcap*tanhf(s);
+            for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+                const float mv = mp_value_base * slope[i_gqa];
+                ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+                    Q_q[i_gqa], k_data, v_data, D, mv, scale, logit_softcap, v->type,
+                    kq_vec_dot, v_to_float, VKQ16[i_gqa], VKQ32[i_gqa], V32[i_gqa], S+i_gqa, M+i_gqa);
            }
+        }
 
-            s += mv; // apply mask
-
-            const float Mold = M;
-
-            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
-            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
-
-            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
-
+        for (int i = 0; i < n_gqa; ++i) {
            if (v->type == GGML_TYPE_F16) {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
-            } else {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
+                for (int64_t d = 0; d < D; ++d) {
+                    VKQ32[i][d] = GGML_FP16_TO_FP32(VKQ16[i][d]);
                }
-
-                v_to_float(v_data, V32, D);
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
            }
 
-            S = S*ms + vs; // scale and increment sum with partial sum
-        }
+            // V /= S
+            const float S_inv = 1.0f/S[i];
+            ggml_vec_scale_f32(D, VKQ32[i], S_inv);
 
-        if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
-                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
-            }
-        }
-
-        // V /= S
-        const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
-
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2 + i;
+            const int i3 = iq3;
 
-        // original
-        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+            // original
+            //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32[i], nb1);
+        }
     }
 }
 
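This hunk is where the redundant memory traffic is removed. Before, the inner ic loop ran once per query head, so every K row and V row was streamed from memory n_gqa times per group; now k_data and v_data are computed once per K/V position and immediately reused for all n_gqa query heads of the group while the row is still in cache, with the per-head state updated through the new helper. A compilable loop-order schematic (the tensor work is elided as comments; this is not the real kernel):

// Loop-order schematic only; the real kernel indexes ggml tensors and calls
// kq_vec_dot / ggml_vec_mad_* inside the loops.
enum { N_HEADS = 4, N_KV = 1024 }; // hypothetical sizes

// old order: every query head re-reads the whole K/V cache
static void old_order(void) {
    for (int h = 0; h < N_HEADS; ++h) {
        for (int ic = 0; ic < N_KV; ++ic) {
            // load K[ic], V[ic]; update the accumulator of head h
        }
    }
}

// new order: each K/V row is loaded once and reused by all heads in the group
static void new_order(void) {
    for (int ic = 0; ic < N_KV; ++ic) {
        for (int h = 0; h < N_HEADS; ++h) {
            // update the accumulator of head h with the already-loaded K[ic], V[ic]
        }
    }
}

int main(void) {
    old_order();
    new_order();
    return 0;
}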
@@ -15132,8 +15169,10 @@ struct ggml_cplan ggml_graph_plan(
             case GGML_OP_FLASH_ATTN_EXT:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D
-
-                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                    const int64_t ne02 = node->src[0]->ne[2]; // n_head
+                    const int64_t ne12 = node->src[1]->ne[2]; // n_head_kv
+                    const int64_t n_gqa = ne02/ne12;
+                    cur = 3*sizeof(float)*ne00*n_tasks*n_gqa; // 3x head size/thread
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
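Since every thread now keeps 3*D floats of scratch per query head in its group, ggml_graph_plan has to reserve n_gqa times more work-buffer space for the FLASH_ATTN_EXT node. A quick numeric check with hypothetical dimensions:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    // hypothetical flash-attention node: D = 128, 32 query heads, 8 KV heads, 8 threads
    const int64_t ne00    = 128; // head size D
    const int64_t ne02    = 32;  // n_head
    const int64_t ne12    = 8;   // n_head_kv
    const int     n_tasks = 8;
    const int64_t n_gqa   = ne02/ne12; // 4

    const size_t before = 3*sizeof(float)*ne00*n_tasks;        // old: 12288 bytes (12 KiB)
    const size_t after  = 3*sizeof(float)*ne00*n_tasks*n_gqa;  // new: 49152 bytes (48 KiB)
    printf("work buffer: %zu -> %zu bytes\n", before, after);
    return 0;
}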
