ggml/src/ggml-cpu/ggml-cpu.c: 140 additions & 101 deletions (241 changes)
@@ -12152,7 +12152,77 @@ static void ggml_compute_forward_argsort(
}

// ggml_compute_forward_flash_attn_ext
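// Streams one K/V column through one query head's online-softmax state:
// computes the scaled (optionally soft-capped and masked) KQ dot product s,
// updates the running maximum *max_kq_value and running sum *sum, and folds
// the V column into the accumulator (VKQ16 when V is F16, otherwise VKQ32,
// converting V through the scratch buffer V32). When s exceeds the previous
// maximum, the accumulator and the sum are rescaled by expf(M_old - M_new).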
static inline void ggml_compute_forward_flash_attn_ext_f16_one_QKV(
const ggml_fp16_t *Q,
const char *K,
const char *V,
const int64_t D,
const float mask_value,
const float scale,
const float logit_softcap,
const enum ggml_type v_type,
ggml_vec_dot_t const kq_vec_dot,
ggml_to_float_t const v_to_float,
ggml_fp16_t *VKQ16,
float *VKQ32,
float *V32,
float *sum,
float *max_kq_value) {
float s; // KQ value
kq_vec_dot(D, &s, 0, K, 0, Q, 0, 1);

s = s*scale; // scale KQ value

if (logit_softcap != 0.0f) {
s = logit_softcap*tanhf(s);
}
s += mask_value; // apply mask
float M = *max_kq_value;
const float Mold = M;

float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
float vs = 1.0f; // post-softmax KQ value, expf(s - M)

if (v_type == GGML_TYPE_F16) {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);

// V = V*expf(Mold - M)
ggml_vec_scale_f16(D, VKQ16, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

// V += v*expf(s - M)
ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) V, vs);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);

// V = V*expf(Mold - M)
ggml_vec_scale_f32(D, VKQ32, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

v_to_float(V, V32, D);

// V += v*expf(s - M)
ggml_vec_mad_f32(D, VKQ32, V32, vs);
}
float S = *sum;
S = S*ms + vs; // scale and increment sum with partial sum
*sum = S;
*max_kq_value = M;
}

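// maximum number of query heads sharing one K/V head (GQA group size)
// supported by the fixed-size per-group state arrays used below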
#define GGML_FLASH_ATTN_EXT_MAX_GQA 16
static void ggml_compute_forward_flash_attn_ext_f16(
const struct ggml_compute_params * params,
const struct ggml_tensor * q,
@@ -12179,6 +12249,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(ne0 == D);
GGML_ASSERT(ne2 == N);

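    // number of query heads attending to each K/V head (GQA group size)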
const int n_gqa = neq2 / nek2;
GGML_ASSERT(n_gqa <= GGML_FLASH_ATTN_EXT_MAX_GQA);
// input tensor rows must be contiguous
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
@@ -12206,15 +12278,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(

// parallelize by q rows using ggml_vec_dot_f32

// total rows in q
const int nr = neq1*neq2*neq3;
// total groups in q
const int ng = neq1*neq2*neq3/n_gqa;

// rows per thread
const int dr = (nr + nth - 1)/nth;
// groups per thread
const int dg = (ng + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
// group range for this thread
const int ig0 = dg*ith;
const int ig1 = MIN(ig0 + dg, ng);

float scale = 1.0f;
float max_bias = 0.0f;
@@ -12242,28 +12314,42 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
GGML_ASSERT(v_to_float && "fattn: unsupported V-type");

// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
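    // per-head online-softmax state and scratch, one slot per query head in the
    // current GQA group; all heads of a group share the same K/V head and are
    // processed together against each K/V column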
float S[GGML_FLASH_ATTN_EXT_MAX_GQA]; // sum
float M[GGML_FLASH_ATTN_EXT_MAX_GQA]; // maximum KQ value
float * VKQ32[GGML_FLASH_ATTN_EXT_MAX_GQA]; // FP32 VKQ accumulator
float * V32[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) FP32 V buffer
ggml_fp16_t * VKQ16[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) FP16 VKQ accumulator
ggml_fp16_t * Q_q[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) buffer for Q converted to quantized/FP16
float slope[GGML_FLASH_ATTN_EXT_MAX_GQA];

// loop over n_batch and n_group
for (int ig = ig0; ig < ig1; ++ig) {
        const int group_index = ig % (neq2 / n_gqa); // head group within the current q row
        const int batch_index = ig / (neq2 / n_gqa); // q row index (iq3 == 0 is assumed below)
// q indices
const int iq3 = ir/(neq2*neq1);
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
const int iq3 = 0;
const int iq2 = group_index * n_gqa; // start head index
const int iq1 = batch_index;

for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
S[i_gqa] = 0.0f;
M[i_gqa] = -INFINITY;
VKQ32 [i_gqa] = (float *) params->wdata + ith*(3*n_gqa*D + CACHE_LINE_SIZE_F32) + 3*i_gqa*D;
V32 [i_gqa] = (VKQ32[i_gqa] + 1*D);
VKQ16 [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 1*D);
Q_q [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 2*D);

const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;

float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value

float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
if (v->type == GGML_TYPE_F16) {
memset(VKQ16[i_gqa], 0, 1*D*sizeof(ggml_fp16_t));
} else {
memset(VKQ32[i_gqa], 0, 1*D*sizeof(float));
}

if (v->type == GGML_TYPE_F16) {
memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
} else {
memset(VKQ32, 0, D*sizeof(float));
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + (iq2 + i_gqa)*nbq2 + iq3*nbq3));
q_to_vec_dot(pq, Q_q[i_gqa], D);

const uint32_t h = iq2 + i_gqa;
slope[i_gqa] = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
}
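        // every head of the group now has zeroed accumulators, its Q row converted
        // to the vec_dot type, and its slope; the K/V columns are streamed next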

const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -12276,95 +12362,46 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;

const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
q_to_vec_dot(pq, Q_q, D);

// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
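        // the K/V row pointers for column ic are computed once and the column is
        // applied to all n_gqa heads of the group in the inner i_gqa loop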
for (int64_t ic = 0; ic < nek1; ++ic) {
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
const float mp_value_base = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mp_value_base == -INFINITY) {
continue;
}

float s; // KQ value

const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);

s = s*scale; // scale KQ value

if (logit_softcap != 0.0f) {
s = logit_softcap*tanhf(s);
for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
const float mv = mp_value_base * slope[i_gqa];
ggml_compute_forward_flash_attn_ext_f16_one_QKV(
Q_q[i_gqa], k_data, v_data, D, mv, scale, logit_softcap, v->type,
kq_vec_dot, v_to_float, VKQ16[i_gqa], VKQ32[i_gqa], V32[i_gqa], S+i_gqa, M+i_gqa);
}
}

s += mv; // apply mask

const float Mold = M;

float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
float vs = 1.0f; // post-softmax KQ value, expf(s - M)

const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

for (int i = 0; i < n_gqa; ++i) {
if (v->type == GGML_TYPE_F16) {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);

// V = V*expf(Mold - M)
ggml_vec_scale_f16(D, VKQ16, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

// V += v*expf(s - M)
ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);

// V = V*expf(Mold - M)
ggml_vec_scale_f32(D, VKQ32, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
for (int64_t d = 0; d < D; ++d) {
VKQ32[i][d] = GGML_FP16_TO_FP32(VKQ16[i][d]);
}

v_to_float(v_data, V32, D);

// V += v*expf(s - M)
ggml_vec_mad_f32(D, VKQ32, V32, vs);
}

S = S*ms + vs; // scale and increment sum with partial sum
}
// V /= S
const float S_inv = 1.0f/S[i];
ggml_vec_scale_f32(D, VKQ32[i], S_inv);

if (v->type == GGML_TYPE_F16) {
for (int64_t d = 0; d < D; ++d) {
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
}
}

// V /= S
const float S_inv = 1.0f/S;
ggml_vec_scale_f32(D, VKQ32, S_inv);

// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// dst indices
const int i1 = iq1;
const int i2 = iq2 + i;
const int i3 = iq3;

// original
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
// original
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));

// permute(0, 2, 1, 3)
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
// permute(0, 2, 1, 3)
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32[i], nb1);
}
}
}

@@ -15132,8 +15169,10 @@ struct ggml_cplan ggml_graph_plan(
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne00 = node->src[0]->ne[0]; // D

cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
const int64_t ne02 = node->src[0]->ne[2]; // n_head
const int64_t ne12 = node->src[1]->ne[2]; // n_head_kv
const int64_t n_gqa = ne02/ne12;
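                    // 3 D-sized float buffers (VKQ accumulator, V/VKQ16 scratch, converted Q)
                    // per head in the GQA group, per thread;
                    // e.g. D = 128, n_gqa = 8, 8 threads -> 3*4*128*8*8 = 98304 bytes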
                    cur = 3*sizeof(float)*ne00*n_tasks*n_gqa; // 3x head size per GQA head, per thread
} break;
case GGML_OP_FLASH_ATTN_BACK:
{