
Commit 3c98bfb

Authored by ikawrakow and Iwan Kawrakow
DeepSeek FA support (CPU only) (#200)
* Adding support for K head size != V head size

  This is relevant for DeepSeek models. At this point ggml CPU FA works.
  Now I need to go and change iqk FA to make it work with Dk != Dv.

* iqk support for K head size != V head size

  To not have compilation time explode, just Dk = 192, Dv = 128 for now (DeepSeek).

* FA: very slightly faster for nq = 1 (TG)

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent a366a3d commit 3c98bfb

File tree

4 files changed: +221 -129 lines changed

ggml/src/ggml.c

Lines changed: 35 additions & 26 deletions
@@ -8473,8 +8473,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
         is_node = true;
     }

+    // k*q will be   { k->ne[1], q->ne[2], q->ne[1], q->ne[3] }
+    // v^T is        { v->ne[1], v->ne[0], v->ne[2], v->ne[3] }
+    // => result is  { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }
     // permute(0, 2, 1, 3)
-    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    //int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

     float params[] = { scale, max_bias, softcap };
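As a worked example of the new shape logic (a sketch, not part of the diff): for a DeepSeek-style setup with a K/Q head size of 192 and a V head size of 128 (the Dk/Dv values named in the commit message), the output row size now follows V rather than Q:

    // assumed illustrative shapes: q->ne[0] == 192 (Dk), v->ne[0] == 128 (Dv)
    int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
    // old: ne[0] was q->ne[0] == 192, over-sizing the output rows
    // new: ne[0] is  v->ne[0] == 128, matching the rows of V*softmax(K^T Q)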
@@ -17436,23 +17440,24 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ith = params->ith;
     const int nth = params->nth;

-    const int64_t D = neq0;
-    const int64_t N = neq1;
+    const int64_t Dk = nek0;
+    const int64_t Dv = nev0;
+    const int64_t N  = neq1;

-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == Dv);
     GGML_ASSERT(ne2 == N);

     // input tensor rows must be contiguous
     GGML_ASSERT(nbq0 == ggml_type_size(q->type));
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));

-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == Dk);
+    GGML_ASSERT(nek0 == Dk);
+    GGML_ASSERT(nev0 == Dv);

     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(nev0 == Dv);

     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -17516,7 +17521,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         int iq1 = (ith%ntg)*neq1g;
         int this_neq1 = MIN(neq1g, neq1-iq1);
         if (!iqk_flash_attn_noalibi(k->type, v->type,
-                    D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+                    Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
                     (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
                     (const void  *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
                     (const void  *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
@@ -17543,6 +17548,8 @@ IQK_Flash_Attn_NotAvailable:;
     ggml_vec_dot_t  const kq_vec_dot = type_traits[k->type].vec_dot;
     ggml_to_float_t const v_to_float = type_traits[v->type].to_float;

+    const int64_t Dkv = MAX(Dk, Dv);
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -17556,15 +17563,15 @@ IQK_Flash_Attn_NotAvailable:;
         float S = 0.0f;      // sum
         float M = -INFINITY; // maximum KQ value

-        float       * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+        float       * VKQ32 = (float *) params->wdata + ith*(3*Dkv + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*Dkv); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*Dkv); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*Dkv); // (temporary) buffer for Q converted to quantized/FP16

         if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+            memset(VKQ16, 0, Dkv*sizeof(ggml_fp16_t));
         } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            memset(VKQ32, 0, Dkv*sizeof(float));
         }

         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
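A back-of-the-envelope check of the Dkv = MAX(Dk, Dv) scratch sizing (not part of the diff): the VKQ accumulators and the V buffer only ever hold Dv values, while Q_q holds the Dk-element converted Q row, so each of the three Dkv-sized slots is large enough for whichever head size it actually stores. With the DeepSeek values from the commit message:

    // assumed Dk = 192, Dv = 128  ->  Dkv = MAX(Dk, Dv) = 192
    // slot 0 (VKQ32):       needs Dv = 128 floats        -> fits in 192
    // slot 1 (V32 / VKQ16): needs Dv values               -> fits in 192
    // slot 2 (Q_q):         needs the Dk = 192 entry row  -> fits in 192
    // per-thread scratch: 3*Dkv + CACHE_LINE_SIZE_F32 floats = 576 floats + padding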
@@ -17578,7 +17585,7 @@ IQK_Flash_Attn_NotAvailable:;
         const int iv2 = iq2 / rv2;

         const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
+        q_to_vec_dot(pq, Q_q, Dk);

         // online softmax / attention
         // loop over n_kv and n_head_kv
@@ -17592,7 +17599,7 @@ IQK_Flash_Attn_NotAvailable:;
             float s; // KQ value

             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+            kq_vec_dot(Dk, &s, 0, k_data, 0, Q_q, 0, 1);

             s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask

@@ -17610,45 +17617,45 @@ IQK_Flash_Attn_NotAvailable:;
                     ms = expf(Mold - M);

                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
+                    ggml_vec_scale_f16(Dv, VKQ16, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }

                 // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+                ggml_vec_mad_f16(Dv, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
                 if (s > M) {
                     // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                     M = s;
                     ms = expf(Mold - M);

                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
+                    ggml_vec_scale_f32(Dv, VKQ32, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }

-                v_to_float(v_data, V32, D);
+                v_to_float(v_data, V32, Dv);

                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+                ggml_vec_mad_f32(Dv, VKQ32, V32, vs);
             }

             S = S*ms + vs; // scale and increment sum with partial sum
         }

         if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
+            for (int64_t d = 0; d < Dv; ++d) {
                 VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
             }
         }

         // V /= S
         const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
+        ggml_vec_scale_f32(Dv, VKQ32, S_inv);

         // dst indices
         const int i1 = iq1;
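For readers unfamiliar with the pattern in this hunk, the following is a minimal, self-contained sketch of the same online-softmax accumulation in plain C. It is simplified (FP32 only, no mask, no softcap) and the names are illustrative, not ggml's:

    #include <math.h>

    // One query row q[Dk] against n_kv key/value rows; out[Dv] receives softmax(scale*Kq) * V.
    // Streaming ("online") softmax: track the running max M and running sum S, rescaling
    // the accumulator whenever a new maximum appears.
    static void online_softmax_attn(const float *q, const float *K, const float *V,
                                    int n_kv, int Dk, int Dv, float scale, float *out) {
        float S = 0.0f;        // running sum of exp(s - M)
        float M = -INFINITY;   // running maximum of the scaled KQ values
        for (int d = 0; d < Dv; ++d) out[d] = 0.0f;

        for (int ic = 0; ic < n_kv; ++ic) {
            float s = 0.0f;
            for (int d = 0; d < Dk; ++d) s += K[ic*Dk + d] * q[d];    // kq_vec_dot(Dk, ...)
            s *= scale;

            float ms = 1.0f;   // rescale factor for the accumulator
            float vs = 1.0f;   // weight of the current V row
            if (s > M) {
                const float Mold = M;
                M  = s;
                ms = expf(Mold - M);                                   // 0 on the first row
                for (int d = 0; d < Dv; ++d) out[d] *= ms;             // ggml_vec_scale_f32(Dv, ...)
            } else {
                vs = expf(s - M);
            }
            for (int d = 0; d < Dv; ++d) out[d] += vs * V[ic*Dv + d];  // ggml_vec_mad_f32(Dv, ...)
            S = S*ms + vs;
        }

        const float S_inv = 1.0f/S;                                    // V /= S
        for (int d = 0; d < Dv; ++d) out[d] *= S_inv;
    }

Note how the K/Q head size Dk only enters the dot product, while every operation on the accumulator uses the V head size Dv, which is exactly the D -> Dk/Dv split made in the hunk above.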
@@ -21112,9 +21119,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                 } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
-                    const int64_t ne00 = node->src[0]->ne[0]; // D
+                    const int64_t Dk = node->src[0]->ne[0];
+                    const int64_t Dv = node->src[2]->ne[0];
+                    const int64_t D  = MAX(Dk, Dv);

-                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                    cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
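A quick sanity check of the new work-buffer sizing (arithmetic only, not from the diff), using the DeepSeek head sizes from the commit message and an assumed n_tasks of 16:

    // Dk = 192 (node->src[0]->ne[0]), Dv = 128 (node->src[2]->ne[0])
    // D   = MAX(Dk, Dv) = 192
    // cur = 3*sizeof(float)*D*n_tasks = 3*4*192*16 = 36864 bytes
    // The old formula used src[0]->ne[0], i.e. the Q/K head size only; a model with
    // Dv > Dk would have under-allocated the per-thread scratch.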
