@@ -12238,23 +12238,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t D = neq0;
-    const int64_t N = neq1;
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
 
-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == DV);
     GGML_ASSERT(ne2 == N);
 
     // input tensor rows must be contiguous
     GGML_ASSERT(nbq0 == ggml_type_size(q->type));
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
 
     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         float S = 0.0f;      // sum
         float M = -INFINITY; // maximum KQ value
 
-        float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float * V32   = (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+        float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float * V32   = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
 
         if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
         } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            memset(VKQ32, 0, DV*sizeof(float));
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
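The hunk above splits the per-thread scratch buffer so that the K head size (DK) and the V head size (DV) no longer have to match. A minimal sketch of the same layout arithmetic, using a hypothetical helper name (not part of ggml) and an assumed 64-byte cache line, to show why 1*DK + 2*DV floats per thread are enough:

#include <stddef.h>
#include <stdint.h>

// Hypothetical helper, not in the patch: floats needed for one thread's scratch slice.
// Layout implied by the pointer arithmetic above:
//   [ VKQ32: DV floats ][ V32 (DV floats) or VKQ16 (DV fp16) ][ Q_q: up to DK elements ][ pad ]
// VKQ16 and Q_q hold fp16/quantized values, so they fit inside the float-sized slots
// reserved for them; hence 1*DK + 2*DV floats plus the cache-line pad.
static size_t fa_scratch_floats_per_thread(int64_t DK, int64_t DV) {
    const size_t cache_line_f32 = 64/sizeof(float); // assumption: 64-byte cache line
    return (size_t)(1*DK + 2*DV) + cache_line_f32;
}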
@@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
1234212342 const int iv2 = iq2 / rv2;
1234312343
1234412344 const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
12345- q_to_vec_dot(pq, Q_q, D );
12345+ q_to_vec_dot(pq, Q_q, DK );
1234612346
1234712347 // online softmax / attention
1234812348 // loop over n_kv and n_head_kv
@@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             float s; // KQ value
 
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
 
             s = s*scale; // scale KQ value
 
@@ -12380,45 +12380,45 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     ms = expf(Mold - M);
 
                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
+                    ggml_vec_scale_f16(DV, VKQ16, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }
 
                 // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
                 if (s > M) {
                     // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                     M = s;
                     ms = expf(Mold - M);
 
                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
+                    ggml_vec_scale_f32(DV, VKQ32, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }
 
-                v_to_float(v_data, V32, D);
+                v_to_float(v_data, V32, DV);
 
                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
             }
 
             S = S*ms + vs; // scale and increment sum with partial sum
         }
 
         if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
+            for (int64_t d = 0; d < DV; ++d) {
                 VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
             }
         }
 
         // V /= S
         const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
+        ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
         // dst indices
         const int i1 = iq1;
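The DV-sized calls in the hunks above all feed the same online-softmax accumulation: keep a running maximum M, a running sum S of exp(s - M), and rescale the V accumulator whenever M grows. A self-contained sketch of one update step in plain FP32 (illustrative function, not a ggml API; it mirrors the non-F16 branch):

#include <math.h>
#include <stdint.h>

// One online-softmax step: fold score s and its value row v[0..DV) into (M, S, VKQ).
static void online_softmax_step(float s, const float * v, int64_t DV,
                                float * M, float * S, float * VKQ) {
    const float Mold = *M;

    float ms = 1.0f; // rescale factor for the existing accumulator
    float vs = 1.0f; // weight of the incoming value row

    if (s > *M) {
        // s is the new maximum: everything accumulated so far gets rescaled
        *M = s;
        ms = expf(Mold - *M);
        for (int64_t d = 0; d < DV; ++d) {
            VKQ[d] *= ms; // V = V*expf(Mold - M)
        }
    } else {
        // no new maximum: only the incoming row is down-weighted
        vs = expf(s - *M);
    }

    for (int64_t d = 0; d < DV; ++d) {
        VKQ[d] += v[d]*vs; // V += v*expf(s - M)
    }

    *S = *S*ms + vs; // running sum of expf(s_i - M)
}

Starting from M = -INFINITY, S = 0 and a zeroed VKQ, dividing VKQ by S after the last step yields the softmax-weighted sum of the value rows, which is exactly what the ggml_vec_scale_f32(DV, VKQ32, S_inv) call above computes.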
@@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
         size_t cur = 0;
 
         if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
-
             switch (node->op) {
                 case GGML_OP_CPY:
                 case GGML_OP_DUP:
@@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
                     } break;
                 case GGML_OP_FLASH_ATTN_EXT:
                     {
-                        const int64_t ne00 = node->src[0]->ne[0]; // D
+                        const int64_t ne10 = node->src[1]->ne[0]; // DK
+                        const int64_t ne20 = node->src[2]->ne[0]; // DV
 
-                        cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                        cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
                     } break;
                 case GGML_OP_FLASH_ATTN_BACK:
                     {
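The new work-size estimate mirrors the kernel's scratch layout: one K-sized slot for the converted Q row plus two V-sized slots for the accumulator and the temporary V buffer, instead of three slots of a single head size D (the old kernel asserted that the Q, K and V head sizes were all equal). A purely illustrative calculation with assumed head sizes, not taken from the patch:

// Assumed example: ne10 (DK) = 192, ne20 (DV) = 128, per task
//   new: sizeof(float)*(1*192 + 2*128) = 4*448 = 1792 bytes
//   old: 3*sizeof(float)*192           = 4*576 = 2304 bytes, and only valid when DK == DV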