
Commit 090a113

remove task inf_type
1 parent e721f4c commit 090a113

File tree

1 file changed: +35 -33 lines


examples/server/server.cpp

Lines changed: 35 additions & 33 deletions
@@ -54,7 +54,10 @@ enum server_state {
 };
 
 enum server_task_type {
-    SERVER_TASK_TYPE_INFERENCE,
+    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_EMBEDDING,
+    SERVER_TASK_TYPE_RERANK,
+    SERVER_TASK_TYPE_INFILL,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
     SERVER_TASK_TYPE_METRICS,
@@ -64,13 +67,6 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };
 
-enum server_task_inf_type {
-    SERVER_TASK_INF_TYPE_COMPLETION,
-    SERVER_TASK_INF_TYPE_EMBEDDING,
-    SERVER_TASK_INF_TYPE_RERANK,
-    SERVER_TASK_INF_TYPE_INFILL,
-};
-
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
 enum error_type {
     ERROR_TYPE_INVALID_REQUEST,
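
With the two enums folded into one, a call site no longer pairs SERVER_TASK_TYPE_INFERENCE with a separate inf_type; it passes the concrete task type directly. A minimal before/after sketch using only identifiers from this commit (the actual changed lines appear in the hunks further down):

    // before: generic inference task plus a separate inf_type
    server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_EMBEDDING);

    // after: the task type itself carries the distinction
    server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);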
@@ -163,8 +159,7 @@ struct server_task {
     int id = -1; // to be filled by server_queue
     int index = -1; // used when there are multiple prompts (batch request)
 
-    server_task_type type;
-    server_task_inf_type inf_type;
+    server_task_type type;
 
     // used by SERVER_TASK_TYPE_CANCEL
     int id_target = -1;
@@ -185,9 +180,7 @@ struct server_task {
     // used by SERVER_TASK_TYPE_METRICS
     bool metrics_reset_bucket = false;
 
-    server_task(
-            server_task_type type,
-            server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION) : type(type), inf_type(inf_type) {}
+    server_task(server_task_type type) : type(type) {}
 
     static slot_params params_from_json_cmpl(
             const llama_model * model,
@@ -893,6 +886,9 @@ struct server_slot {
     int id;
     int id_task = -1;
 
+    // only used for completion/embedding/infill/rerank
+    server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
+
     llama_batch batch_spec = {};
 
     llama_context * ctx = nullptr;
@@ -931,8 +927,6 @@ struct server_slot {
     llama_tokens cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
-    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
-
     bool has_next_token = true;
     bool has_new_line = false;
     bool truncated = false;
@@ -972,11 +966,15 @@ struct server_slot {
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
+        task_type = SERVER_TASK_TYPE_COMPLETION;
 
         generated_token_probs.clear();
     }
 
+    bool is_non_causal() const {
+        return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
+    }
+
     bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
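
The new is_non_causal() helper centralizes the embedding/rerank check that was previously spelled out at each call site; a sketch of the simplification (the exact replacements appear in the update_slots hunks below):

    // before: condition repeated wherever it was needed
    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) { /* ... */ }

    // after: one helper on the slot
    if (slot.is_non_causal()) { /* ... */ }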
@@ -1088,6 +1086,7 @@ struct server_slot {
             {"n_ctx", n_ctx},
             {"speculative", can_speculate()},
             {"is_processing", is_processing()},
+            {"non_causal", is_non_causal()},
             {"params", params.to_json()},
             {"prompt", common_detokenize(ctx, prompt_tokens)},
             {"next_token",
@@ -1653,8 +1652,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot.reset();
         slot.id_task = task.id;
-        slot.inf_type = task.inf_type;
         slot.index = task.index;
+        slot.task_type = task.type;
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
@@ -2120,7 +2119,10 @@ struct server_context {
 
     void process_single_task(server_task task) {
         switch (task.type) {
-            case SERVER_TASK_TYPE_INFERENCE:
+            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFILL:
+            case SERVER_TASK_TYPE_EMBEDDING:
+            case SERVER_TASK_TYPE_RERANK:
                 {
                     const int id_slot = task.id_selected_slot;
 
@@ -2462,7 +2464,7 @@ struct server_context {
                     continue;
                 }
 
-                if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                if (slot.is_non_causal()) {
                     if (slot.n_prompt_tokens > n_ubatch) {
                         slot.release();
                         send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -2577,18 +2579,15 @@ struct server_context {
                 }
 
                 // non-causal tasks require to fit the entire prompt in the physical batch
-                if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                if (slot.is_non_causal()) {
                     // cannot fit the prompt in the current batch - will try next iter
                     if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                         continue;
                     }
                 }
 
                 // check that we are in the right batch_type, if not defer the slot
-                const bool slot_type =
-                    slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
-                    slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
-
+                int slot_type = slot.is_non_causal();
                 if (batch_type == -1) {
                     batch_type = slot_type;
                 } else if (batch_type != slot_type) {
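
As before, causal and non-causal (embedding/rerank) slots are not mixed in the same batch; slot_type now simply reuses the helper. A condensed, commented sketch of the rule shown in this hunk (comments added here for illustration):

    int slot_type = slot.is_non_causal();   // 1 for embedding/rerank, 0 for completion/infill
    if (batch_type == -1) {
        batch_type = slot_type;             // the first slot picked decides the batch type
    } else if (batch_type != slot_type) {
        continue;                           // defer slots of the other type to a later batch
    }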
@@ -2705,15 +2704,15 @@ struct server_context {
                 }
 
                 if (slot.state == SLOT_STATE_DONE_PROMPT) {
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
+                    if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                         // prompt evaluated for embedding
                         send_embedding(slot, batch_view);
                         slot.release();
                         slot.i_batch = -1;
                         continue; // continue loop of slots
                     }
 
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                    if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
                         send_rerank(slot, batch_view);
                         slot.release();
                         slot.i_batch = -1;
@@ -3352,11 +3351,13 @@ int main(int argc, char ** argv) {
     // handle completion-like requests (completion, chat, infill)
    // we can optionally provide a custom format for partial results and final results
     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
-            server_task_inf_type inf_type,
+            server_task_type type,
             json & data,
             httplib::Response & res,
             bool oaicompat = false,
             bool oaicompat_chat = false) {
+        GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
+
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
@@ -3369,7 +3370,8 @@ int main(int argc, char ** argv) {
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
         tasks.reserve(tokenized_prompts.size());
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-            server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, inf_type);
+            server_task task = server_task(type);
+
             task.id = ctx_server.queue_tasks.get_new_id();
             task.index = i;
 
@@ -3450,7 +3452,7 @@ int main(int argc, char ** argv) {
     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
         return handle_completions_generic(
-            SERVER_TASK_INF_TYPE_COMPLETION,
+            SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
             /* oaicompat */ false,
@@ -3504,7 +3506,7 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it's not exist
 
-        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };
 
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -3515,7 +3517,7 @@ int main(int argc, char ** argv) {
 
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
         return handle_completions_generic(
-            SERVER_TASK_INF_TYPE_COMPLETION,
+            SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
             /* oaicompat */ true,
@@ -3616,7 +3618,7 @@ int main(int argc, char ** argv) {
         std::vector<server_task> tasks;
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-            server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_EMBEDDING);
+            server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
             task.id = ctx_server.queue_tasks.get_new_id();
             task.index = i;
             task.prompt_tokens = std::move(tokenized_prompts[i]);
@@ -3698,7 +3700,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
         tasks.reserve(tokenized_docs.size());
         for (size_t i = 0; i < tokenized_docs.size(); i++) {
-            server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_RERANK);
+            server_task task = server_task(SERVER_TASK_TYPE_RERANK);
             task.id = ctx_server.queue_tasks.get_new_id();
             task.index = i;
             task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
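
Taken together, the HTTP handlers now pass a concrete server_task_type all the way through. A condensed sketch of the resulting call shapes, using only handler and constant names from this diff:

    // completion, chat and infill endpoints route through the shared generic handler;
    // the GGML_ASSERT inside it rejects any other task type
    handle_completions_generic(SERVER_TASK_TYPE_COMPLETION, data, res);   // or SERVER_TASK_TYPE_INFILL

    // embedding and rerank endpoints build their tasks directly, one per prompt/document
    server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);           // or SERVER_TASK_TYPE_RERANK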
