Skip to content

Commit 637503b

Browse files
committed
llama : fix FA when KV cache is not used (i.e. embeddings) (llama/12825)
* ggml : FA supports F32 V * graph : cast KV to F16 when the KV cache is not used ggml-ci * server : add test that exercises embeddings with FA enabled ggml-ci
1 parent 2af91a9 commit 637503b

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

src/ggml-cpu/ops.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6769,8 +6769,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
67696769
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
67706770
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
67716771

6772-
GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
6773-
GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
6772+
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
6773+
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
67746774

67756775
// loop over n_batch and n_head
67766776
for (int ir = ir0; ir < ir1; ++ir) {
@@ -6866,10 +6866,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
68666866
vs = expf(s - M);
68676867
}
68686868

6869-
v_to_float(v_data, V32, DV);
6870-
68716869
// V += v*expf(s - M)
6872-
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
6870+
if (v_to_float) {
6871+
v_to_float(v_data, V32, DV);
6872+
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
6873+
} else {
6874+
// V is F32
6875+
ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
6876+
}
68736877
}
68746878

68756879
S = S*ms + vs; // scale and increment sum with partial sum

src/ggml-metal/ggml-metal.m

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,6 +1346,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
13461346
case GGML_OP_ARANGE:
13471347
return true;
13481348
case GGML_OP_FLASH_ATTN_EXT:
1349+
if (op->src[0]->ne[0] == 32) {
1350+
// head size == 32 (e.g. bert-bge-small)
1351+
// TODO: not sure if it is worth adding kernels for this size
1352+
return false;
1353+
}
13491354
if (op->src[1]->type != op->src[2]->type) {
13501355
return false;
13511356
}

0 commit comments

Comments
 (0)