 #include "json-schema-to-grammar.h"
 #include "llama.h"
 #include "grammar-parser.h"
+#include "tool-call.hpp"
 
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -157,6 +158,7 @@ struct server_slot {
     std::string generated_text;
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
+    enum llama_response_state response_state = LLAMA_RESPONSE_STATE_UNKNOWN;
 
     bool infill = false;
     bool embedding = false;
@@ -207,6 +209,7 @@ struct server_slot {
         infill = false;
         ga_i = 0;
         n_past_se = 0;
+        response_state = LLAMA_RESPONSE_STATE_UNKNOWN;
 
         generated_token_probs.clear();
     }
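
The new `response_state` member (and the `tool_format` member added below) relies on types pulled in from `tool-call.hpp`, which is not part of this diff. A minimal sketch of what the server code appears to assume that header declares; the enumerator names come straight from the diff, everything else (including the extra format value) is hypothetical:

// tool-call.hpp (sketch only, not the actual header)
#pragma once

// which tool-call convention the loaded model speaks
enum llama_tool_format {
    LLAMA_TOOL_FORMAT_NOT_SUPPORTED,
    LLAMA_TOOL_FORMAT_FUNCTIONARY_V3, // hypothetical example of a supported format
};

// classification of the text generated so far
enum llama_response_state {
    LLAMA_RESPONSE_STATE_UNKNOWN,   // not enough tokens yet to decide
    LLAMA_RESPONSE_STATE_TEXT,      // plain text, safe to stream token by token
    LLAMA_RESPONSE_STATE_TOOL_CALL, // buffer the whole response, then parse it
};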
@@ -625,6 +628,7 @@ struct server_context {
    llama_model * model = nullptr;
    llama_context * ctx = nullptr;
    std::vector<llama_lora_adapter_container> lora_adapters;
+   llama_tool_format tool_format = LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
 
    gpt_params params;
 
@@ -1217,7 +1221,13 @@ struct server_context {
             break;
         }
 
-        if (!incomplete) {
+        if (slot.response_state == LLAMA_RESPONSE_STATE_UNKNOWN) {
+            slot.response_state = check_response_state(tool_format, slot.generated_text);
+        }
+
+        // if response is tool call, we cannot stream it
+        // instead, we wait for the full response, then extract JSON
+        if (!incomplete && slot.response_state == LLAMA_RESPONSE_STATE_TEXT) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
             const std::string str_test = slot.generated_text.substr(pos);
@@ -1247,9 +1257,7 @@ struct server_context {
             if (slot.params.stream) {
                 send_partial_response(slot, result);
            }
-        }
-
-        if (incomplete) {
+        } else {
             slot.has_next_token = true;
         }
 
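check_response_state() itself lives in tool-call.hpp/.cpp and is not shown in this diff. A sketch of one way it could work, assuming the decision is made by looking for a format-specific marker at the start of the generated text (the marker string below is purely illustrative):

static llama_response_state check_response_state(llama_tool_format format, const std::string & generated_text) {
    if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
        return LLAMA_RESPONSE_STATE_TEXT; // tool calls disabled, always stream
    }
    const std::string marker = "<|tool_call|>"; // hypothetical, depends on the model's template
    if (generated_text.size() < marker.size()) {
        return LLAMA_RESPONSE_STATE_UNKNOWN; // keep buffering until we can tell
    }
    return generated_text.compare(0, marker.size(), marker) == 0
        ? LLAMA_RESPONSE_STATE_TOOL_CALL
        : LLAMA_RESPONSE_STATE_TEXT;
}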
@@ -1396,6 +1404,10 @@ struct server_context {
             {"multimodal", false}
         };
 
+        if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
+            res.data["tool_calls"] = parse_tool_response(tool_format, tkn.text_to_send);
+        }
+
         if (slot.sparams.n_probs > 0) {
             const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
             const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
@@ -1444,6 +1456,10 @@ struct server_context {
             {"timings", slot.get_formated_timings()}
         };
 
+        if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
+            res.data["tool_calls"] = parse_tool_response(tool_format, slot.generated_text);
+        }
+
         if (slot.sparams.n_probs > 0) {
             std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
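
In both the partial and final response paths, `res.data["tool_calls"]` is filled from parse_tool_response(), so the helper presumably returns a JSON array shaped like the OpenAI `tool_calls` field. A sketch under that assumption; the marker handling and payload shape are guesses, not the PR's actual parser, and it assumes the tool-call prefix was already detected:

static json parse_tool_response(llama_tool_format format, const std::string & generated_text) {
    (void) format; // a real implementation would branch on the format here
    json calls = json::array();
    // hypothetical: assume the model emitted {"name": ..., "arguments": {...}} after the marker
    const std::string marker = "<|tool_call|>";
    const json parsed = json::parse(generated_text.substr(marker.size()), nullptr, /* allow_exceptions */ false);
    if (parsed.is_object() && parsed.contains("name")) {
        calls.push_back({
            {"type", "function"},
            {"function", {
                {"name",      parsed.at("name")},
                {"arguments", parsed.value("arguments", json::object()).dump()},
            }},
        });
    }
    return calls;
}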
@@ -2937,19 +2953,14 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
-        std::string template_key = "tokenizer.chat_template", curr_tmpl;
-        int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
-        if (tlen > 0) {
-            std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-            if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-                curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
-            }
-        }
+        std::string chat_tmpl = ctx_server.params.chat_template.empty()
+            ? llama_get_chat_template(ctx_server.model)
+            : ctx_server.params.chat_template;
         json data = {
             { "system_prompt", ctx_server.system_prompt.c_str() },
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots", ctx_server.params.n_parallel },
-            { "chat_template", curr_tmpl.c_str() }
+            { "chat_template", chat_tmpl },
         };
 
         res.set_content(data.dump(), MIMETYPE_JSON);
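
The handler now delegates to llama_get_chat_template(), which this diff does not define; presumably it wraps the same `tokenizer.chat_template` metadata lookup that the removed lines performed inline. A sketch under that assumption:

static std::string llama_get_chat_template(const struct llama_model * model) {
    // read the GGUF "tokenizer.chat_template" metadata, mirroring the code removed above
    std::string template_key = "tokenizer.chat_template";
    int32_t tlen = llama_model_meta_val_str(model, template_key.c_str(), nullptr, 0);
    if (tlen <= 0) {
        return "";
    }
    std::vector<char> curr_tmpl_buf(tlen + 1, 0);
    if (llama_model_meta_val_str(model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) != tlen) {
        return "";
    }
    return std::string(curr_tmpl_buf.data(), tlen);
}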
@@ -3056,7 +3067,13 @@ int main(int argc, char ** argv) {
             res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        json body = json::parse(req.body);
+
+        if (body.contains("tools") && ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
+            body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools"));
+        }
+
+        json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template);
 
         const int id_task = ctx_server.queue_tasks.get_new_id();
 
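For reference, the new branch only fires when the request body carries a `tools` array (and the model's tool format is supported); the body is expected in the OpenAI chat-completions shape. An illustrative example of such a body, with made-up field values:

const json example_body = json::parse(R"({
    "messages": [
        { "role": "user", "content": "What is the weather in Hanoi?" }
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": { "city": { "type": "string" } },
                    "required": ["city"]
                }
            }
        }
    ]
})");
// with tool support enabled, format_chat_with_tool() renders messages + tools into body["prompt"]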
@@ -3423,11 +3440,15 @@ int main(int argc, char ** argv) {
         }
     }
 
+    // decide if we can enable tool calls
+    ctx_server.tool_format = get_tool_format(ctx_server.ctx);
+
     // print sample chat example to make it clear which template is used
     {
         LOG_INFO("chat template", {
-            {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
-            {"built_in", params.chat_template.empty()},
+            {"chat_example",       llama_chat_format_example(ctx_server.model, params.chat_template)},
+            {"built_in",           params.chat_template.empty()},
+            {"tool_call_enabled",  ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED},
         });
     }
 
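Finally, get_tool_format() (also from tool-call.hpp) decides at startup whether tool calls can be enabled for the loaded model. One plausible sketch, assuming it inspects the model's chat template for a known tool-call convention; the marker and the heuristic are guesses:

static llama_tool_format get_tool_format(const struct llama_context * ctx) {
    const struct llama_model * model = llama_get_model(ctx);
    const std::string tmpl = llama_get_chat_template(model);
    // hypothetical heuristic: look for a template marker associated with a known format
    if (tmpl.find("<|recipient|>") != std::string::npos) {
        return LLAMA_TOOL_FORMAT_FUNCTIONARY_V3; // hypothetical value from the sketch above
    }
    return LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
}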