@@ -8473,8 +8473,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
         is_node = true;
     }

+    // k*q will be { k->ne[1], q->ne[2], q->ne[1], q->ne[3] }
+    // v^T is      { v->ne[1], v->ne[0], v->ne[2], v->ne[3] }
+    // => result is { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }
     // permute(0, 2, 1, 3)
-    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    //int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

     float params[] = { scale, max_bias, softcap };
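Why this is the right shape: the attention output for each query is a softmax-weighted combination of V rows, so its first (row) dimension must come from V's head size, while the head, token and batch dimensions still come from Q. A minimal sketch of that shape arithmetic, using hypothetical MLA-like dimensions (Dk = 576, Dv = 512, 16 heads, 32 query tokens) instead of real ggml tensors:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    // hypothetical shapes, following the layout assumed above:
    //   q = { Dk, n_q_tokens, n_head, n_batch }
    //   v = { Dv, n_kv,       n_head, n_batch }
    int64_t q_ne[4] = { 576,   32, 16, 1 };
    int64_t v_ne[4] = { 512, 4096, 16, 1 };

    // same formula as the patched line: { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }
    int64_t ne[4] = { v_ne[0], q_ne[2], q_ne[1], q_ne[3] };

    // prints: dst = { 512, 16, 32, 1 }; the old formula would give 576 for ne[0]
    printf("dst = { %lld, %lld, %lld, %lld }\n",
           (long long) ne[0], (long long) ne[1], (long long) ne[2], (long long) ne[3]);
    return 0;
}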
@@ -17436,23 +17440,24 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ith = params->ith;
     const int nth = params->nth;

-    const int64_t D = neq0;
-    const int64_t N = neq1;
+    const int64_t Dk = nek0;
+    const int64_t Dv = nev0;
+    const int64_t N  = neq1;

-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == Dv);
     GGML_ASSERT(ne2 == N);

     // input tensor rows must be contiguous
     GGML_ASSERT(nbq0 == ggml_type_size(q->type));
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));

-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == Dk);
+    GGML_ASSERT(nek0 == Dk);
+    GGML_ASSERT(nev0 == Dv);

     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(nev0 == Dv);

     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -17516,7 +17521,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         int iq1 = (ith%ntg)*neq1g;
         int this_neq1 = MIN(neq1g, neq1-iq1);
         if (!iqk_flash_attn_noalibi(k->type, v->type,
-                        D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+                        Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
                         (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
                         (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
                         (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
@@ -17543,6 +17548,8 @@ IQK_Flash_Attn_NotAvailable:;
    ggml_vec_dot_t  const kq_vec_dot = type_traits[k->type].vec_dot;
    ggml_to_float_t const v_to_float = type_traits[v->type].to_float;

+    const int64_t Dkv = MAX(Dk, Dv);
+
    // loop over n_batch and n_head
    for (int ir = ir0; ir < ir1; ++ir) {
        // q indices
@@ -17556,15 +17563,15 @@ IQK_Flash_Attn_NotAvailable:;
        float S = 0.0f;      // sum
        float M = -INFINITY; // maximum KQ value

-        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+        float       * VKQ32 = (float       *) params->wdata + ith*(3*Dkv + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*Dkv); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*Dkv); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*Dkv); // (temporary) buffer for Q converted to quantized/FP16

        if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+            memset(VKQ16, 0, Dkv*sizeof(ggml_fp16_t));
        } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            memset(VKQ32, 0, Dkv*sizeof(float));
        }

        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
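The per-thread scratch now has to be sized for the larger of the two head dimensions, because VKQ32/V32/VKQ16 hold Dv-wide vectors while Q_q holds the Dk-wide converted query, and all three regions share a single stride. A small sketch of that layout in plain C, with a hypothetical 16-float pad standing in for CACHE_LINE_SIZE_F32 (assumed 64-byte cache lines):

#include <stdint.h>
#include <stdio.h>

#define CACHE_LINE_PAD_F32 16   // assumed stand-in for CACHE_LINE_SIZE_F32

int main(void) {
    const int64_t Dk = 576, Dv = 512;        // hypothetical MLA-like head sizes
    const int64_t Dkv = Dk > Dv ? Dk : Dv;   // MAX(Dk, Dv), as in the patch

    // per-thread layout inside wdata, in float units:
    //   [0,      Dkv)   VKQ32       (only the first Dv floats are used)
    //   [Dkv,  2*Dkv)   V32 / VKQ16 (aliased; Dv floats / Dv halfs used)
    //   [2*Dkv, 3*Dkv)  Q_q         (Dk converted query elements, fp16/quantized)
    const int64_t per_thread = 3*Dkv + CACHE_LINE_PAD_F32;

    for (int ith = 0; ith < 4; ++ith) {
        printf("thread %d: VKQ32 at %lld, V32/VKQ16 at %lld, Q_q at %lld (floats)\n",
               ith,
               (long long)(ith*per_thread),
               (long long)(ith*per_thread + Dkv),
               (long long)(ith*per_thread + 2*Dkv));
    }
    return 0;
}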
@@ -17578,7 +17585,7 @@ IQK_Flash_Attn_NotAvailable:;
        const int iv2 = iq2 / rv2;

        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
+        q_to_vec_dot(pq, Q_q, Dk);

        // online softmax / attention
        // loop over n_kv and n_head_kv
@@ -17592,7 +17599,7 @@ IQK_Flash_Attn_NotAvailable:;
            float s; // KQ value

            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+            kq_vec_dot(Dk, &s, 0, k_data, 0, Q_q, 0, 1);

            s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask

@@ -17610,45 +17617,45 @@ IQK_Flash_Attn_NotAvailable:;
                    ms = expf(Mold - M);

                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
+                    ggml_vec_scale_f16(Dv, VKQ16, ms);
                } else {
                    // no new maximum, ms == 1.0f, vs != 1.0f
                    vs = expf(s - M);
                }

                // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+                ggml_vec_mad_f16(Dv, VKQ16, (const ggml_fp16_t *) v_data, vs);
            } else {
                if (s > M) {
                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                    M = s;
                    ms = expf(Mold - M);

                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
+                    ggml_vec_scale_f32(Dv, VKQ32, ms);
                } else {
                    // no new maximum, ms == 1.0f, vs != 1.0f
                    vs = expf(s - M);
                }

-                v_to_float(v_data, V32, D);
+                v_to_float(v_data, V32, Dv);

                // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+                ggml_vec_mad_f32(Dv, VKQ32, V32, vs);
            }

            S = S*ms + vs; // scale and increment sum with partial sum
        }

        if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
+            for (int64_t d = 0; d < Dv; ++d) {
                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
            }
        }

        // V /= S
        const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
+        ggml_vec_scale_f32(Dv, VKQ32, S_inv);

        // dst indices
        const int i1 = iq1;
@@ -21112,9 +21119,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                } break;
            case GGML_OP_FLASH_ATTN_EXT:
                {
-                    const int64_t ne00 = node->src[0]->ne[0]; // D
+                    const int64_t Dk = node->src[0]->ne[0];
+                    const int64_t Dv = node->src[2]->ne[0];
+                    const int64_t D  = MAX(Dk, Dv);

-                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                    cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
                } break;
            case GGML_OP_FLASH_ATTN_BACK:
                {
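Finally, the planner's work-size estimate has to match the per-thread layout above, so it reserves three vectors of MAX(Dk, Dv) floats per task instead of three of Q's head size. A throwaway comparison of the two formulas, with hypothetical dimensions standing in for src[0]->ne[0] and src[2]->ne[0]:

#include <stdint.h>
#include <stdio.h>

#define MAX(a, b) ((a) > (b) ? (a) : (b))

int main(void) {
    const int64_t Dk = 576, Dv = 512;   // hypothetical K/V head sizes
    const int     n_tasks = 8;

    const int64_t old_cur = (int64_t) (3*sizeof(float))*Dk*n_tasks;           // old: Q/K head size only
    const int64_t new_cur = (int64_t) (3*sizeof(float))*MAX(Dk, Dv)*n_tasks;  // new: covers either buffer

    // identical when Dk >= Dv; the new formula also stays safe if Dv were the larger one
    printf("old = %lld bytes, new = %lld bytes\n", (long long) old_cur, (long long) new_cur);
    return 0;
}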