
Commit e309c6a

ochafik committed
tool-call: integrate minja & tool-call to server when --jinja is set
1 parent 3cfc21e commit e309c6a

10 files changed: +514 −163 lines changed


common/arg.cpp

Lines changed: 10 additions & 2 deletions
@@ -1844,13 +1844,21 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(llama_arg(
+        {"--jinja"},
+        "use jinja template for chat (default: disabled)",
+        [](gpt_params & params) {
+            params.use_jinja = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         "set custom jinja chat template (default: template taken from model's metadata)\n"
         "if suffix/prefix are specified, template will be disabled\n"
-        "only commonly used templates are accepted:\nhttps://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
+        "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
+        "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
         [](gpt_params & params, const std::string & value) {
-            if (!llama_chat_verify_template(value)) {
+            if (!llama_chat_verify_template(value, params.use_jinja)) {
                 throw std::runtime_error(format(
                     "error: the supplied chat template is not supported: %s\n"
                     "note: llama.cpp does not use jinja parser, we only support commonly used templates\n",

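The updated help text says the Jinja-aware check only applies if --jinja is set before --chat-template: each option's handler runs as its flag is parsed, so the --chat-template lambda reads whatever value params.use_jinja holds at that moment. Below is a minimal standalone sketch of that ordering effect, using hypothetical simplified types rather than the actual llama_arg API.

#include <cstdio>
#include <functional>
#include <string>
#include <utility>
#include <vector>

struct params_t { bool use_jinja = false; };

int main() {
    params_t params;
    // pretend the command line was: --jinja --chat-template <tmpl>
    std::vector<std::pair<std::string, std::function<void(params_t &)>>> parsed = {
        {"--jinja",         [](params_t & p) { p.use_jinja = true; }},
        {"--chat-template", [](params_t & p) {
            // validation would run here and sees the value set by --jinja
            std::printf("validating template with use_jinja=%d\n", p.use_jinja); // prints 1
        }},
    };
    for (auto & [flag, handler] : parsed) {
        std::printf("parsing %s\n", flag.c_str());
        handler(params);
    }
    return 0;
}

With the flags reversed, the --chat-template handler would still see use_jinja == false and fall back to the strict built-in-template check.
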
common/common.cpp

Lines changed: 17 additions & 9 deletions
@@ -1510,16 +1510,20 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
 // Chat template utils
 //

-bool llama_chat_verify_template(const std::string & tmpl) {
+bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0, use_jinja);
     return res >= 0;
 }

 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & msgs,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja,
+        const std::string & tools,
+        const char * bos_token,
+        const char * eos_token) {
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
@@ -1532,7 +1536,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     std::vector<char> buf(alloc_size);

     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools.empty() ? nullptr : tools.data(), bos_token, eos_token);

     // error: chat template is not supported
     if (res < 0) {
@@ -1542,7 +1546,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
             throw std::runtime_error("this custom template is not supported");
         } else {
             // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token);
             fallback = true;
         }
     }
@@ -1553,7 +1557,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         res = llama_chat_apply_template(
             fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
-            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token);
     }

     std::string formatted_chat(buf.data(), res);
@@ -1564,17 +1568,21 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & past_msg,
         const llama_chat_msg & new_msg,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja,
+        const std::string & tools,
+        const char * bos_token,
+        const char * eos_token) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
+    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, bos_token, eos_token);
     std::vector<llama_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
         ss << "\n";
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
+    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, bos_token, eos_token);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
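
The wrapper now threads use_jinja, an optional tools JSON string, and explicit BOS/EOS pieces through to llama_chat_apply_template. A hedged usage sketch based only on the signatures in this commit; the message contents are illustrative, and the empty tmpl (which falls back to the model's built-in template) is an assumption about typical usage.

#include <string>
#include <vector>

#include "common.h"

std::string format_next_turn(const llama_model * model,
                             const std::vector<llama_chat_msg> & history) {
    llama_chat_msg user_msg;
    user_msg.role    = "user";
    user_msg.content = "What is the weather in Paris?";

    // use_jinja=true routes formatting through the new Jinja path; tools,
    // bos_token and eos_token keep their defaults ("" and nullptr).
    return llama_chat_format_single(model, /* tmpl */ "", history, user_msg,
                                    /* add_ass */ true, /* use_jinja */ true);
}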

common/common.h

Lines changed: 19 additions & 4 deletions
@@ -285,6 +285,7 @@ struct gpt_params {
     std::string public_path = ""; // NOLINT
     std::string chat_template = ""; // NOLINT
     std::string system_prompt = ""; // NOLINT
+    bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;

     std::vector<std::string> api_keys;
@@ -469,29 +470,43 @@ std::string llama_detokenize(
 // Chat template utils
 //

-// same with llama_chat_message, but uses std::string
+// same as llama_chat_message, but uses std::string and std::vector
 struct llama_chat_msg {
     std::string role;
     std::string content;
+    std::string tool;
+    struct llama_tool_call {
+        std::string name;
+        std::string arguments;
+    };
+    std::vector<llama_tool_call> tool_calls;
 };

 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool llama_chat_verify_template(const std::string & tmpl);
+bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false);

 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & chat,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja = false,
+        const std::string & tools = "",
+        const char * bos_token = nullptr,
+        const char * eos_token = nullptr);

 // Format single message, while taking into account the position of that message in chat history
 std::string llama_chat_format_single(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & past_msg,
         const llama_chat_msg & new_msg,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja = false,
+        const std::string & tools = "",
+        const char * bos_token = nullptr,
+        const char * eos_token = nullptr);

 // Returns an example of formatted chat
 std::string llama_chat_format_example(const struct llama_model * model,
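
llama_chat_msg now carries an optional tool name and a list of structured tool calls alongside role/content. A hedged sketch of how a caller might populate the struct and pass a tools description to the extended llama_chat_apply_template; the OpenAI-style shape of the tools and arguments JSON is an assumption here, not something this header fixes.

#include <string>
#include <vector>

#include "common.h"

std::string render_tool_call_turn(const llama_model * model) {
    std::vector<llama_chat_msg> msgs;
    msgs.push_back({/* role */ "user", /* content */ "What is the weather in Paris?",
                    /* tool */ "", /* tool_calls */ {}});

    // assistant turn carrying a structured tool call
    llama_chat_msg assistant;
    assistant.role = "assistant";
    assistant.tool_calls.push_back({/* name */ "get_weather",
                                    /* arguments */ "{\"location\": \"Paris\"}"});
    msgs.push_back(assistant);

    // tools description forwarded to the template as a JSON string
    const std::string tools = R"([{"type": "function", "function": {"name": "get_weather"}}])";

    return llama_chat_apply_template(model, /* tmpl */ "", msgs, /* add_ass */ true,
                                     /* use_jinja */ true, tools);
}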

examples/server/server.cpp

Lines changed: 3 additions & 1 deletion
@@ -2781,6 +2781,8 @@ int main(int argc, char ** argv) {
             { "system_prompt", ctx_server.system_prompt.c_str() },
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots", ctx_server.params.n_parallel },
+            { "bos_token", llama_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), true) },
+            { "eos_token", llama_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), true) },
             { "chat_template", curr_tmpl.c_str() },
         };

@@ -2854,7 +2856,7 @@ int main(int argc, char ** argv) {
             return;
         }

-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template, params.use_jinja);

         std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
         ctx_server.queue_results.add_waiting_tasks(tasks);
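
/props now also exposes the detokenized BOS and EOS pieces, which is the kind of information a Jinja chat template typically needs as bos_token/eos_token. A hedged sketch of deriving those pieces the same way the handler above does and forwarding them to the extended wrapper; the function and variable names here are placeholders, not server code.

#include <string>
#include <vector>

#include "common.h"
#include "llama.h"

std::string render_with_model_tokens(llama_model * model, llama_context * ctx,
                                     const std::vector<llama_chat_msg> & msgs) {
    // same derivation the /props handler above uses
    const std::string bos = llama_token_to_piece(ctx, llama_token_bos(model), /* special */ true);
    const std::string eos = llama_token_to_piece(ctx, llama_token_eos(model), /* special */ true);

    return llama_chat_apply_template(model, /* tmpl */ "", msgs, /* add_ass */ true,
                                     /* use_jinja */ true, /* tools */ "",
                                     bos.c_str(), eos.c_str());
}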

examples/server/tests/features/steps/steps.py

Lines changed: 35 additions & 8 deletions
@@ -75,6 +75,8 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.server_seed = None
     context.user_api_key = None
     context.response_format = None
+    context.tools = None
+    context.tool_choice = None
     context.temperature = None
     context.lora_file = None
     context.disable_ctx_shift = False
@@ -363,6 +365,13 @@ def step_max_tokens(context, max_tokens):
 def step_response_format(context, response_format):
     context.response_format = json.loads(response_format)

+@step('tools {tools}')
+def step_tools(context, tools):
+    context.tools = json.loads(tools)
+
+@step('tool choice {tool_choice}')
+def step_tool_choice(context, tool_choice):
+    context.tool_choice = tool_choice

 @step('{temperature:f} temperature')
 def step_temperature(context, temperature):
@@ -497,6 +506,11 @@ async def step_oai_chat_completions(context, api_error):
             response_format=context.response_format
             if hasattr(context, 'response_format') else None,

+            tools=context.tools
+            if hasattr(context, 'tools') else None,
+
+            tool_choice=context.tool_choice,
+
             user_api_key=context.user_api_key
             if hasattr(context, 'user_api_key') else None,

@@ -567,6 +581,9 @@ async def step_oai_chat_completions(context):
             if hasattr(context, 'enable_streaming') else None,
             response_format=context.response_format
             if hasattr(context, 'response_format') else None,
+            tools=context.tools
+            if hasattr(context, 'tools') else None,
+            tool_choice=context.tool_choice,
             user_api_key=context.user_api_key
             if hasattr(context, 'user_api_key') else None)

@@ -580,16 +597,18 @@ async def step_oai_chat_completions(context):
         context.base_url,
         '/chat/completions',
         True,  # async_client
-        model=context.model
-        if hasattr(context, 'model') else None,
-        n_predict=context.n_predict
-        if hasattr(context, 'n_predict') else None,
+        model=context.model,
+        # if hasattr(context, 'model') else None,
+        n_predict=context.n_predict,
+        # if hasattr(context, 'n_predict') else None,
         enable_streaming=context.enable_streaming
         if hasattr(context, 'enable_streaming') else None,
-        response_format=context.response_format
-        if hasattr(context, 'response_format') else None,
-        user_api_key=context.user_api_key
-        if hasattr(context, 'user_api_key') else None)
+        response_format=context.response_format,
+        # if hasattr(context, 'response_format') else None,
+        tools=context.tools,  # if hasattr(context, 'tools') else None,
+        tool_choice=context.tool_choice,  # if hasattr(context, 'tool_choice') else None,
+        user_api_key=context.user_api_key)
+        # if hasattr(context, 'user_api_key') else None)


 @step('all prompts are predicted')
@@ -974,6 +993,8 @@ async def oai_chat_completions(user_prompt,
                                n_predict=None,
                                enable_streaming=None,
                                response_format=None,
+                               tools=None,
+                               tool_choice=None,
                                user_api_key=None,
                                expect_api_error=None) -> int | dict[str, Any]:
     if debug:
@@ -1001,6 +1022,10 @@ async def oai_chat_completions(user_prompt,
     }
     if response_format is not None:
         payload['response_format'] = response_format
+    if tools is not None:
+        payload['tools'] = tools
+    if tool_choice is not None:
+        payload['tool_choice'] = tool_choice
     completion_response = {
         'content': '',
         'timings': {
@@ -1065,6 +1090,8 @@ async def oai_chat_completions(user_prompt,
                 max_tokens=n_predict,
                 stream=enable_streaming,
                 response_format=payload.get('response_format') or openai.NOT_GIVEN,
+                tools=payload.get('tools'),
+                tool_choice=payload.get('tool_choice'),
                 seed=seed,
                 temperature=payload['temperature']
             )
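
The new test steps simply forward tools and tool_choice into the OpenAI-style /chat/completions payload. For reference, a hedged sketch of that request body shape in C++ with nlohmann::json (which the server side already uses); the concrete tool schema and model name are assumptions, and only the tools and tool_choice keys come from this diff.

#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Builds a chat request mirroring what the new "tools" / "tool choice" steps send.
json build_chat_request() {
    json payload;
    payload["model"] = "anything";  // placeholder model name

    json message;
    message["role"]    = "user";
    message["content"] = "What is the weather in Paris?";
    payload["messages"] = json::array({message});

    json tool;
    tool["type"] = "function";
    tool["function"]["name"]       = "get_weather";    // hypothetical tool
    tool["function"]["parameters"] = json::object();

    payload["tools"]       = json::array({tool});
    payload["tool_choice"] = "auto";
    return payload;
}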

0 commit comments

Comments
 (0)