Skip to content
10 changes: 5 additions & 5 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1791,11 +1791,11 @@ extern "C" {

#define GGML_KQ_MASK_PAD 64

// q: [n_embd, n_batch, n_head, 1]
// k: [n_embd, n_kv, n_head_kv, 1]
// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd, n_head, n_batch, 1] !! permuted !!
// q: [n_embd_k, n_batch, n_head, 1]
// k: [n_embd_k, n_kv, n_head_kv, 1]
// v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
Expand Down
50 changes: 25 additions & 25 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -12238,23 +12238,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int ith = params->ith;
const int nth = params->nth;

const int64_t D = neq0;
const int64_t N = neq1;
const int64_t DK = nek0;
const int64_t DV = nev0;
const int64_t N = neq1;

GGML_ASSERT(ne0 == D);
GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N);

// input tensor rows must be contiguous
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));

GGML_ASSERT(neq0 == D);
GGML_ASSERT(nek0 == D);
GGML_ASSERT(nev0 == D);
GGML_ASSERT(neq0 == DK);
GGML_ASSERT(nek0 == DK);
GGML_ASSERT(nev0 == DV);

GGML_ASSERT(neq1 == N);
GGML_ASSERT(nev0 == D);

// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
Expand Down Expand Up @@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value

float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16

if (v->type == GGML_TYPE_F16) {
memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
} else {
memset(VKQ32, 0, D*sizeof(float));
memset(VKQ32, 0, DV*sizeof(float));
}

const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
Expand All @@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int iv2 = iq2 / rv2;

const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
q_to_vec_dot(pq, Q_q, D);
q_to_vec_dot(pq, Q_q, DK);

// online softmax / attention
// loop over n_kv and n_head_kv
Expand All @@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float s; // KQ value

const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);

s = s*scale; // scale KQ value

Expand All @@ -12380,45 +12380,45 @@ static void ggml_compute_forward_flash_attn_ext_f16(
ms = expf(Mold - M);

// V = V*expf(Mold - M)
ggml_vec_scale_f16(D, VKQ16, ms);
ggml_vec_scale_f16(DV, VKQ16, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

// V += v*expf(s - M)
ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);

// V = V*expf(Mold - M)
ggml_vec_scale_f32(D, VKQ32, ms);
ggml_vec_scale_f32(DV, VKQ32, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

v_to_float(v_data, V32, D);
v_to_float(v_data, V32, DV);

// V += v*expf(s - M)
ggml_vec_mad_f32(D, VKQ32, V32, vs);
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
}

S = S*ms + vs; // scale and increment sum with partial sum
}

if (v->type == GGML_TYPE_F16) {
for (int64_t d = 0; d < D; ++d) {
for (int64_t d = 0; d < DV; ++d) {
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
}
}

// V /= S
const float S_inv = 1.0f/S;
ggml_vec_scale_f32(D, VKQ32, S_inv);
ggml_vec_scale_f32(DV, VKQ32, S_inv);

// dst indices
const int i1 = iq1;
Expand Down Expand Up @@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
size_t cur = 0;

if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {

switch (node->op) {
case GGML_OP_CPY:
case GGML_OP_DUP:
Expand Down Expand Up @@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne00 = node->src[0]->ne[0]; // D
const int64_t ne10 = node->src[1]->ne[0]; // DK
const int64_t ne20 = node->src[2]->ne[0]; // DV

cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3232,6 +3232,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[0]->ne[3] != 1) {
return false;
}
Expand Down
9 changes: 6 additions & 3 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,12 @@ typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb_12_1;
uint64_t nb_12_2;
uint64_t nb_12_3;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
uint64_t nb31;
int32_t ne1;
int32_t ne2;
Expand Down
Loading
Loading