
Commit aadc68b

server : support both embeddings and completions with single model
ggml-ci
1 parent 0d03605 commit aadc68b

3 files changed: +48 -14 lines changed

src/llama-context.cpp

Lines changed: 7 additions & 0 deletions
@@ -41,6 +41,7 @@ llama_context::llama_context(
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.defrag_thold = params.defrag_thold;
     cparams.embeddings = params.embeddings;
+    cparams.embeddings_org = params.embeddings;
     cparams.offload_kqv = params.offload_kqv;
     cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;
@@ -629,6 +630,12 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void
 }

 void llama_context::set_embeddings(bool value) {
+    if (value && !cparams.embeddings_org) {
+        LLAMA_LOG_ERROR("%s: cannot enable embeddings for this context (%s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/14208");
+        return;
+    }
+
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

     cparams.embeddings = value;
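Note on the new guard: llama_set_embeddings() now refuses to enable embeddings on a context that was created with cparams.embeddings == false; it logs an error pointing at the PR and returns without changing the flag, while a context created with embeddings enabled can still be toggled both ways. Below is a minimal sketch of that contract, assuming a placeholder model path and the llama.h loading API as I understand it; it is illustrative only and not part of this commit.

    #include "llama.h"

    int main() {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
        if (!model) {
            return 1;
        }

        llama_context_params cparams = llama_context_default_params();
        cparams.embeddings = false; // recorded as cparams.embeddings_org = false at creation
        llama_context * ctx = llama_init_from_model(model, cparams);

        llama_set_embeddings(ctx, true);  // rejected by the new guard: logs an error, flag stays false
        llama_set_embeddings(ctx, false); // still allowed

        // a context created with cparams.embeddings = true can be switched in both directions

        llama_free(ctx);
        llama_model_free(model);
        return 0;
    }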

src/llama-cparams.h

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ struct llama_cparams {
     float defrag_thold;

     bool embeddings;
+    bool embeddings_org;
     bool causal_attn;
     bool offload_kqv;
     bool flash_attn;
tools/server/server.cpp

Lines changed: 40 additions & 14 deletions
@@ -88,6 +88,26 @@ enum error_type {
     ERROR_TYPE_NOT_SUPPORTED, // custom error
 };

+static bool server_task_type_need_embd(server_task_type task_type) {
+    switch (task_type) {
+        case SERVER_TASK_TYPE_EMBEDDING:
+        case SERVER_TASK_TYPE_RERANK:
+            return true;
+        default:
+            return false;
+    }
+}
+
+static bool server_task_type_need_logits(server_task_type task_type) {
+    switch (task_type) {
+        case SERVER_TASK_TYPE_COMPLETION:
+        case SERVER_TASK_TYPE_INFILL:
+            return true;
+        default:
+            return false;
+    }
+}
+
 struct slot_params {
     bool stream = true;
     bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
@@ -1330,6 +1350,14 @@ struct server_slot {
         n_draft_accepted = 0;
     }

+    bool need_embd() const {
+        return server_task_type_need_embd(task_type);
+    }
+
+    bool need_logits() const {
+        return server_task_type_need_logits(task_type);
+    }
+
     bool can_batch_with(server_slot & other_slot) const {
         return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora);
     }
@@ -3095,6 +3123,13 @@ struct server_context {
                 continue;
             }

+            // TODO: support memory-less logits computation
+            if (slot.need_logits() && !llama_get_memory(ctx)) {
+                slot.release();
+                send_error(slot, "the current context does not support logits computation. skipping", ERROR_TYPE_SERVER);
+                continue;
+            }
+
             if (!can_split()) {
                 if (slot.n_prompt_tokens > n_ubatch) {
                     slot.release();
@@ -3292,7 +3327,7 @@
                     }

                     // embedding requires all tokens in the batch to be output
-                    const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING || slot.task_type == SERVER_TASK_TYPE_RERANK;
+                    const bool need_embd = server_task_type_need_embd(slot.task_type);

                     common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
                     slot.cache_tokens.push_back(cur_tok);
@@ -3348,13 +3383,13 @@
         if (slot_batched) {
             // apply lora, only need to do it once per batch
             common_set_adapter_lora(ctx, slot_batched->lora);
-        }

-        const bool do_encode = params_base.embedding;
+            llama_set_embeddings(ctx, slot_batched->need_embd());
+        }

         // pad the batch so that batch.n_tokens >= n_slots
         // TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
-        if (do_encode) {
+        if (llama_get_embeddings(ctx)) {
             const int n_slots = slots.size();

             if (batch.n_tokens < n_slots) {
@@ -4175,11 +4210,6 @@ int main(int argc, char ** argv) {
             oaicompat_type oaicompat) -> void {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);

-        if (ctx_server.params_base.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
-            return;
-        }
-
         auto completion_id = gen_chatcmplid();
         std::unordered_set<int> task_ids;
         try {
@@ -4434,12 +4464,8 @@
                 OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
     };

-    const auto handle_chat_completions = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         LOG_DBG("request: %s\n", req.body.c_str());
-        if (ctx_server.params_base.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
-            return;
-        }

         auto body = json::parse(req.body);
         std::vector<raw_buffer> files;
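Taken together with the per-batch llama_set_embeddings() call above, these handler changes mean a server started with `--embeddings` no longer rejects completion requests: the same context is switched between embedding and logits output depending on the task type of the slots batched together. A rough, self-contained sketch of that dual-mode pattern at the library level follows; the model path and prompt are placeholders, API names are taken from llama.h as I understand it, and the program is illustrative only, not part of this commit.

    #include "llama.h"

    #include <cstdio>
    #include <vector>

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
        if (!model) {
            return 1;
        }

        llama_context_params cparams = llama_context_default_params();
        cparams.embeddings = true; // created with embeddings enabled, so both modes stay available
        llama_context * ctx = llama_init_from_model(model, cparams);

        const llama_vocab * vocab = llama_model_get_vocab(model);

        std::vector<llama_token> tokens(64);
        const int n_tok = llama_tokenize(vocab, "hello world", 11, tokens.data(), (int) tokens.size(),
                                         /*add_special=*/true, /*parse_special=*/false);
        tokens.resize(n_tok > 0 ? n_tok : 0);

        llama_batch batch = llama_batch_get_one(tokens.data(), (int) tokens.size());

        // embedding-style pass: ask the context for embeddings output
        llama_set_embeddings(ctx, true);
        llama_decode(ctx, batch);
        const float * embd = llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE
            ? llama_get_embeddings_ith(ctx, -1)  // per-token embedding of the last output token
            : llama_get_embeddings_seq(ctx, 0);  // pooled embedding of sequence 0
        printf("embedding available: %s\n", embd ? "yes" : "no");

        // completion-style pass on the same context: back to logits output
        llama_set_embeddings(ctx, false);
        llama_decode(ctx, batch);
        const float * logits = llama_get_logits_ith(ctx, -1); // logits of the last output token
        printf("logits available: %s\n", logits ? "yes" : "no");

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }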
