File tree — 4 files changed: +42 −2 lines changed
lines changed Original file line number Diff line number Diff line change @@ -49,6 +49,26 @@ def test_embedding_multiple():
4949 assert len (d ['embedding' ]) > 1
5050
5151
def test_embedding_multiple_with_fa():
    """Embed several inputs against a server started with flash attention.

    The prompt lengths are chosen so that at least one of them should hit
    the FA branch (i.e. context size % 256 == 0).
    """
    server = ServerPreset.bert_bge_small_with_fa()
    server.pooling = 'last'
    server.start()

    prompts = ["a " * 253, "b " * 254, "c " * 255, "d " * 256]
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": prompts,
    })

    assert res.status_code == 200
    results = res.body['data']
    assert len(results) == 4
    # every input must come back with a non-trivial embedding vector
    for entry in results:
        assert 'embedding' in entry
        assert len(entry['embedding']) > 1
71+
5272@pytest .mark .parametrize (
5373 "input,is_multi_prompt" ,
5474 [
Original file line number Diff line number Diff line change @@ -323,6 +323,21 @@ def bert_bge_small() -> ServerProcess:
323323 server .server_embeddings = True
324324 return server
325325
326+ @staticmethod
327+ def bert_bge_small_with_fa () -> ServerProcess :
328+ server = ServerProcess ()
329+ server .model_hf_repo = "ggml-org/models"
330+ server .model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
331+ server .model_alias = "bert-bge-small"
332+ server .n_ctx = 1024
333+ server .n_batch = 300
334+ server .n_ubatch = 300
335+ server .n_slots = 2
336+ server .fa = True
337+ server .seed = 42
338+ server .server_embeddings = True
339+ return server
340+
326341 @staticmethod
327342 def tinyllama_infill () -> ServerProcess :
328343 server = ServerProcess ()
Original file line number Diff line number Diff line change @@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
67216721 ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu (k->type )->vec_dot ;
67226722 ggml_to_float_t const v_to_float = ggml_get_type_traits (v->type )->to_float ;
67236723
6724- GGML_ASSERT ( q_to_vec_dot && " fattn: unsupported K-type" );
6725- GGML_ASSERT (v->type == GGML_TYPE_F32 || v_to_float && " fattn: unsupported V-type" );
6724+ GGML_ASSERT (( q_to_vec_dot) && " fattn: unsupported K-type" );
6725+ GGML_ASSERT (( v->type == GGML_TYPE_F32 || v_to_float ) && " fattn: unsupported V-type" );
67266726
67276727 // loop over n_batch and n_head
67286728 for (int ir = ir0; ir < ir1; ++ir) {
Original file line number Diff line number Diff line change @@ -1345,6 +1345,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
13451345 case GGML_OP_ARANGE:
13461346 return true ;
13471347 case GGML_OP_FLASH_ATTN_EXT:
1348+ if (op->src [0 ]->ne [0 ] == 32 ) {
1349+ // head size == 32 (e.g. bert-bge-small)
1350+ // TODO: not sure if it is worth adding kernels for this size
1351+ return false ;
1352+ }
13481353 if (op->src [1 ]->type != op->src [2 ]->type ) {
13491354 return false ;
13501355 }
You can’t perform that action at this time.
0 commit comments