@@ -12152,7 +12152,77 @@ static void ggml_compute_forward_argsort(
 }
 
 // ggml_compute_forward_flash_attn_ext
+static inline void ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+        const ggml_fp16_t *Q,
+        const char *K,
+        const char *V,
+        const int64_t D,
+        const float mask_value,
+        const float scale,
+        const float logit_softcap,
+        const enum ggml_type v_type,
+        ggml_vec_dot_t    const kq_vec_dot,
+        ggml_to_float_t   const v_to_float,
+        ggml_fp16_t *VKQ16,
+        float *VKQ32,
+        float *V32,
+        float *sum,
+        float *max_kq_value) {
+    float s; // KQ value
+    kq_vec_dot(D, &s, 0, K, 0, Q, 0, 1);
+
+    s = s*scale; // scale KQ value
+
+    if (logit_softcap != 0.0f) {
+        s = logit_softcap*tanhf(s);
+    }
+    s += mask_value; // apply mask
+    float M = *max_kq_value;
+    const float Mold = M;
+
+    float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+    float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+    if (v_type == GGML_TYPE_F16) {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
+
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f16(D, VKQ16, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
+
+        // V += v*expf(s - M)
+        ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) V, vs);
+    } else {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
+
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f32(D, VKQ32, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
 
+        v_to_float(V, V32, D);
+
+        // V += v*expf(s - M)
+        ggml_vec_mad_f32(D, VKQ32, V32, vs);
+    }
+    float S = *sum;
+    S = S*ms + vs; // scale and increment sum with partial sum
+    *sum = S;
+    *max_kq_value = M;
+}
+
+#define GGML_FLASH_ATTN_EXT_MAX_GQA 16
 static void ggml_compute_forward_flash_attn_ext_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * q,
@@ -12179,6 +12249,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne0 == D);
     GGML_ASSERT(ne2 == N);
 
+    const int n_gqa = neq2 / nek2;
+    GGML_ASSERT(n_gqa <= GGML_FLASH_ATTN_EXT_MAX_GQA);
     // input tensor rows must be contiguous
     GGML_ASSERT(nbq0 == ggml_type_size(q->type));
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
@@ -12206,15 +12278,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
     // parallelize by q rows using ggml_vec_dot_f32
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
+    // total groups in q
+    const int ng = neq1*neq2*neq3/n_gqa;
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // groups per thread
+    const int dg = (ng + nth - 1)/nth;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    // group range for this thread
+    const int ig0 = dg*ith;
+    const int ig1 = MIN(ig0 + dg, ng);
 
     float scale         = 1.0f;
     float max_bias      = 0.0f;
@@ -12242,28 +12314,42 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
     GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
 
-    // loop over n_batch and n_head
-    for (int ir = ir0; ir < ir1; ++ir) {
+    float S[GGML_FLASH_ATTN_EXT_MAX_GQA]; // sum
+    float M[GGML_FLASH_ATTN_EXT_MAX_GQA]; // maximum KQ value
+    float       * VKQ32[GGML_FLASH_ATTN_EXT_MAX_GQA]; // FP32 VKQ accumulator
+    float       * V32[GGML_FLASH_ATTN_EXT_MAX_GQA];   // (temporary) FP32 V buffer
+    ggml_fp16_t * VKQ16[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) FP16 VKQ accumulator
+    ggml_fp16_t * Q_q[GGML_FLASH_ATTN_EXT_MAX_GQA];   // (temporary) buffer for Q converted to quantized/FP16
+    float slope[GGML_FLASH_ATTN_EXT_MAX_GQA];
+
+    // loop over n_batch and n_group
+    for (int ig = ig0; ig < ig1; ++ig) {
+        const int group_index = ig % (neq2 / n_gqa); // index of the GQA head group within one q row
+        const int batch_index = ig / (neq2 / n_gqa); // q row (token) index
         // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+        const int iq3 = 0; // assumes neq3 == 1
+        const int iq2 = group_index * n_gqa; // start head index
+        const int iq1 = batch_index;
+
+        for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+            S[i_gqa] = 0.0f;
+            M[i_gqa] = -INFINITY;
+            VKQ32 [i_gqa] = (float *) params->wdata + ith*(3*n_gqa*D + CACHE_LINE_SIZE_F32) + 3*i_gqa*D;
+            V32   [i_gqa] = (VKQ32[i_gqa] + 1*D);
+            VKQ16 [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 1*D);
+            Q_q   [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 2*D);
 
-        const uint32_t h = iq2; // head index
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float S = 0.0f;      // sum
-        float M = -INFINITY; // maximum KQ value
-
-        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+            if (v->type == GGML_TYPE_F16) {
+                memset(VKQ16[i_gqa], 0, 1*D*sizeof(ggml_fp16_t));
+            } else {
+                memset(VKQ32[i_gqa], 0, 1*D*sizeof(float));
+            }
 
-        if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
-        } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + (iq2 + i_gqa)*nbq2 + iq3*nbq3));
+            q_to_vec_dot(pq, Q_q[i_gqa], D);
+
+            const uint32_t h = iq2 + i_gqa;
+            slope[i_gqa] = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -12276,95 +12362,46 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
-        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
-
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
-            if (mv == -INFINITY) {
+            const float mp_value_base = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mp_value_base == -INFINITY) {
                 continue;
             }
-
-            float s; // KQ value
-
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
-
-            s = s*scale; // scale KQ value
-
-            if (logit_softcap != 0.0f) {
-                s = logit_softcap*tanhf(s);
+            for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+                const float mv = mp_value_base * slope[i_gqa];
+                ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+                    Q_q[i_gqa], k_data, v_data, D, mv, scale, logit_softcap, v->type,
+                    kq_vec_dot, v_to_float, VKQ16[i_gqa], VKQ32[i_gqa], V32[i_gqa], S+i_gqa, M+i_gqa);
             }
+        }
 
-            s += mv; // apply mask
-
-            const float Mold = M;
-
-            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
-            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
-
-            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
-
+        for (int i = 0; i < n_gqa; ++i) {
             if (v->type == GGML_TYPE_F16) {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
-            } else {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
+                for (int64_t d = 0; d < D; ++d) {
+                    VKQ32[i][d] = GGML_FP16_TO_FP32(VKQ16[i][d]);
                 }
-
-                v_to_float(v_data, V32, D);
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
             }
 
-            S = S*ms + vs; // scale and increment sum with partial sum
-        }
+            // V /= S
+            const float S_inv = 1.0f/S[i];
+            ggml_vec_scale_f32(D, VKQ32[i], S_inv);
 
-        if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
-                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
-            }
-        }
-
-        // V /= S
-        const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
-
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2 + i;
+            const int i3 = iq3;
 
-        // original
-        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+            // original
+            //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32[i], nb1);
+        }
     }
 }
 
@@ -15132,8 +15169,10 @@ struct ggml_cplan ggml_graph_plan(
                 case GGML_OP_FLASH_ATTN_EXT:
                     {
                         const int64_t ne00 = node->src[0]->ne[0]; // D
-
-                        cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                        const int64_t ne02 = node->src[0]->ne[2]; // n_head
+                        const int64_t ne12 = node->src[1]->ne[2]; // n_head_kv
+                        const int64_t n_gqa = ne02/ne12;
+                        cur = 3*sizeof(float)*ne00*n_tasks*n_gqa; // 3x head size per grouped head, per thread
                     } break;
                 case GGML_OP_FLASH_ATTN_BACK:
                     {
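
Note on the ggml_graph_plan hunk above: the per-thread scratch for GGML_OP_FLASH_ATTN_EXT grows by a factor of n_gqa, since each thread now keeps the three D-sized float buffers (VKQ32, the V32/VKQ16 overlay, and Q_q) for every query head in a GQA group. A minimal sketch of the resulting size, using assumed example values for D, n_tasks and n_gqa (the per-thread cache-line padding that ggml_graph_plan adds separately is left out here):

#include <stdio.h>

int main(void) {
    // assumed example values, not taken from the patch
    const long ne00    = 128; // D, head dimension
    const long n_tasks = 8;   // worker threads
    const long n_gqa   = 4;   // query heads per KV head (neq2/nek2)

    // mirrors: cur = 3*sizeof(float)*ne00*n_tasks*n_gqa;
    const long cur = 3*(long)sizeof(float)*ne00*n_tasks*n_gqa;
    printf("FLASH_ATTN_EXT scratch: %ld bytes\n", cur); // 3*4*128*8*4 = 49152
    return 0;
}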
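
For reference, the new helper ggml_compute_forward_flash_attn_ext_f16_one_QKV factors out one step of the online-softmax accumulation (ref: https://arxiv.org/pdf/2112.05682.pdf): a running maximum M, a running denominator S and a running V accumulator are updated per K/V row, and the accumulator is rescaled whenever a new maximum appears. Below is a self-contained toy sketch of that update in plain C with made-up scores and values; all names here are illustrative only and not part of ggml.

#include <math.h>
#include <stdio.h>

#define N_KV 4 // toy number of K/V rows
#define D    2 // toy head dimension

int main(void) {
    const float s_kq[N_KV] = {0.5f, 2.0f, -1.0f, 1.5f};          // toy scaled KQ scores
    const float v[N_KV][D] = {{1, 0}, {0, 1}, {2, 2}, {1, 1}};   // toy V rows

    float M = -INFINITY;   // running maximum score
    float S = 0.0f;        // running softmax denominator
    float acc[D] = {0.0f}; // running sum of expf(s - M)*v, like VKQ32

    for (int ic = 0; ic < N_KV; ++ic) {
        const float s = s_kq[ic];
        float ms = 1.0f; // rescales previous accumulator and sum on a new maximum
        float vs = 1.0f; // weight of the current V row, expf(s - M)
        if (s > M) {
            ms = expf(M - s); // shrink old contributions relative to the new maximum
            M  = s;
        } else {
            vs = expf(s - M);
        }
        for (int d = 0; d < D; ++d) {
            acc[d] = acc[d]*ms + vs*v[ic][d];
        }
        S = S*ms + vs;
    }

    // acc/S equals softmax(s_kq) applied to the V rows
    for (int d = 0; d < D; ++d) {
        printf("out[%d] = %f\n", d, acc[d]/S);
    }
    return 0;
}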