Skip to content

Commit 36c776f

Browse files
author
ochafik
committed
Finish renaming of chat inputs vs. params [skip ci]
1 parent ed7c622 commit 36c776f

File tree

3 files changed: +67 -62 lines changed

common/common.cpp

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1776,12 +1776,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
17761776
if (use_jinja) {
17771777
try {
17781778
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
1779-
common_chat_inputs params;
1780-
params.messages = json::array({{
1779+
common_chat_inputs inputs;
1780+
inputs.messages = json::array({{
17811781
{"role", "user"},
17821782
{"content", "test"},
17831783
}});
1784-
common_chat_params_init(chat_template, params);
1784+
common_chat_params_init(chat_template, inputs);
17851785
return true;
17861786
} catch (const std::exception & e) {
17871787
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
@@ -1803,11 +1803,10 @@ std::string common_chat_apply_template(
18031803
for (const auto & msg : msgs) {
18041804
messages.push_back({{"role", msg.role}, {"content", msg.content}});
18051805
}
1806-
common_chat_inputs params;
1807-
params.messages = messages;
1808-
params.add_generation_prompt = add_ass;
1809-
auto data = common_chat_params_init(tmpl, params);
1810-
return data.prompt;
1806+
common_chat_inputs inputs;
1807+
inputs.messages = messages;
1808+
inputs.add_generation_prompt = add_ass;
1809+
return common_chat_params_init(tmpl, inputs).prompt;
18111810
}
18121811

18131812
int alloc_size = 0;

examples/server/server.cpp

Lines changed: 20 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1824,16 +1824,16 @@ struct server_context {
18241824

18251825
if (use_jinja) {
18261826
auto templates = common_chat_templates_from_model(model, "");
1827-
common_chat_inputs params;
1828-
params.messages = json::array({{
1827+
common_chat_inputs inputs;
1828+
inputs.messages = json::array({{
18291829
{"role", "user"},
18301830
{"content", "test"},
18311831
}});
18321832
GGML_ASSERT(templates.template_default);
18331833
try {
1834-
common_chat_params_init(*templates.template_default, params);
1834+
common_chat_params_init(*templates.template_default, inputs);
18351835
if (templates.template_tool_use) {
1836-
common_chat_params_init(*templates.template_tool_use, params);
1836+
common_chat_params_init(*templates.template_tool_use, inputs);
18371837
}
18381838
return true;
18391839
} catch (const std::exception & e) {
@@ -3787,10 +3787,10 @@ int main(int argc, char ** argv) {
37873787
std::vector<server_task> tasks;
37883788

37893789
try {
3790-
common_chat_params chat_data;
3790+
common_chat_params chat_params;
37913791
bool add_special = false;
37923792
if (tmpl && ctx_server.params_base.use_jinja) {
3793-
chat_data = common_chat_params_init(*tmpl, {
3793+
chat_params = common_chat_params_init(*tmpl, {
37943794
/* .messages = */ json_value(data, "messages", json::array()),
37953795
/* .tools = */ json_value(data, "tools", json()),
37963796
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
@@ -3799,28 +3799,28 @@ int main(int argc, char ** argv) {
37993799
/* .stream = */ json_value(data, "stream", false),
38003800
/* .grammar = */ json_value(data, "grammar", std::string("")),
38013801
});
3802-
LOG_INF("Chat format: %s\n", chat_data.format.c_str());
3803-
LOG_DBG("Prompt: %s\n", chat_data.prompt.get<std::string>().c_str());
3804-
LOG_DBG("Grammar: %s\n", chat_data.grammar.c_str());
3802+
LOG_INF("Chat format: %s\n", chat_params.format.c_str());
3803+
LOG_DBG("Prompt: %s\n", chat_params.prompt.get<std::string>().c_str());
3804+
LOG_DBG("Grammar: %s\n", chat_params.grammar.c_str());
38053805
if (data.contains("grammar")) {
3806-
if (!chat_data.grammar.empty()) {
3806+
if (!chat_params.grammar.empty()) {
38073807
throw std::runtime_error("Cannot provide grammar and tools");
38083808
}
3809-
chat_data.grammar = data.at("grammar");
3809+
chat_params.grammar = data.at("grammar");
38103810
}
38113811
// TODO: move inside minja:chat_template?
38123812
add_special = tmpl->source().find("eos_token") == std::string::npos &&
38133813
tmpl->source().find("bos_token") == std::string::npos;
38143814
} else {
38153815
add_special = true;
3816-
chat_data.prompt = data.at("prompt");
3816+
chat_params.prompt = data.at("prompt");
38173817
if (data.contains("grammar")) {
3818-
chat_data.grammar = data.at("grammar");
3818+
chat_params.grammar = data.at("grammar");
38193819
} else if (data.contains("json_schema")) {
3820-
chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
3820+
chat_params.grammar = json_schema_to_grammar(data.at("json_schema"));
38213821
}
38223822
}
3823-
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, add_special, true);
3823+
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_params.prompt, add_special, true);
38243824
tasks.reserve(tokenized_prompts.size());
38253825
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
38263826
server_task task = server_task(type);
@@ -3838,9 +3838,9 @@ int main(int argc, char ** argv) {
38383838
// OAI-compat
38393839
task.params.oaicompat = oaicompat;
38403840
task.params.oaicompat_cmpl_id = completion_id;
3841-
task.params.sampling.grammar = chat_data.grammar;
3842-
task.params.sampling.grammar_lazy = chat_data.grammar_lazy;
3843-
for (const auto & trigger : chat_data.grammar_triggers) {
3841+
task.params.sampling.grammar = chat_params.grammar;
3842+
task.params.sampling.grammar_lazy = chat_params.grammar_lazy;
3843+
for (const auto & trigger : chat_params.grammar_triggers) {
38443844
auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
38453845
if (ids.size() == 1) {
38463846
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
@@ -3850,8 +3850,8 @@ int main(int argc, char ** argv) {
38503850
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
38513851
task.params.sampling.grammar_trigger_words.push_back(trigger);
38523852
}
3853-
task.params.antiprompt = chat_data.additional_stops;
3854-
task.params.chat_parser = chat_data.parser;
3853+
task.params.antiprompt = chat_params.additional_stops;
3854+
task.params.chat_parser = chat_params.parser;
38553855
if (task.params.sampling.grammar_lazy) {
38563856
GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
38573857
}

tests/test-chat.cpp

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -169,18 +169,18 @@ struct delta_data {
169169
};
170170

171171
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
172-
common_chat_inputs params;
173-
params.parallel_tool_calls = true;
174-
params.messages = json::array();
175-
params.messages.push_back(user_message);
176-
params.tools = tools;
177-
auto prefix_data = common_chat_params_init(tmpl, params);
178-
params.messages.push_back(delta_message);
179-
params.add_generation_prompt = false;
180-
auto full_data = common_chat_params_init(tmpl, params);
181-
182-
std::string prefix = prefix_data.prompt;
183-
std::string full = full_data.prompt;
172+
common_chat_inputs inputs;
173+
inputs.parallel_tool_calls = true;
174+
inputs.messages = json::array();
175+
inputs.messages.push_back(user_message);
176+
inputs.tools = tools;
177+
auto params_prefix = common_chat_params_init(tmpl, inputs);
178+
inputs.messages.push_back(delta_message);
179+
inputs.add_generation_prompt = false;
180+
auto params_full = common_chat_params_init(tmpl, inputs);
181+
182+
std::string prefix = params_prefix.prompt;
183+
std::string full = params_full.prompt;
184184

185185
// Check full starts with prefix
186186
if (full.find(prefix) != 0) {
@@ -203,7 +203,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
203203
break;
204204
}
205205
}
206-
return {delta, full_data.grammar, full_data.parser};
206+
return {delta, params_full.grammar, params_full.parser};
207207
}
208208

209209
/*
@@ -220,12 +220,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
220220
};
221221

222222
for (const auto & tool_choice : json({"auto", "required"})) {
223-
common_chat_inputs params;
224-
params.tool_choice = tool_choice;
225-
params.parallel_tool_calls = true;
226-
params.messages = json {user_message, test_message};
227-
params.tools = tools;
228-
229223
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
230224
if (!expected_delta.empty()) {
231225
assert_equals(expected_delta, data.delta);
@@ -309,17 +303,18 @@ static void test_template_output_parsers() {
309303
tools_params.tools.push_back(special_function_tool);
310304

311305
auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
312-
auto data = common_chat_params_init(tmpl, params);
313-
return data.format;
306+
return common_chat_params_init(tmpl, params).format;
314307
};
315308

316309
{
317-
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
310+
const common_chat_template tmpl(read_file(
311+
"models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
318312
std::vector<std::string> end_tokens { "<end_of_turn>" };
319313

320314
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
321315
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
322-
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
316+
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file(
317+
"models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
323318

324319
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
325320
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
@@ -340,7 +335,8 @@ static void test_template_output_parsers() {
340335
"}");
341336
}
342337
{
343-
const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
338+
const common_chat_template tmpl(read_file(
339+
"models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
344340
std::vector<std::string> end_tokens { "</s>" };
345341

346342
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
@@ -351,12 +347,15 @@ static void test_template_output_parsers() {
351347
/* skip_grammar_test= */ true);
352348
}
353349
{
354-
const common_chat_template tmpl(read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
350+
const common_chat_template tmpl(read_file(
351+
"models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
355352
std::vector<std::string> end_tokens { "<|im_end|>" };
356353

357354
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
358-
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
359-
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
355+
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
356+
"models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
357+
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
358+
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
360359

361360
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
362361
test_template(tmpl, end_tokens, tool_call_message, tools,
@@ -369,11 +368,13 @@ static void test_template_output_parsers() {
369368
"</tool_call>");
370369
}
371370
{
372-
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
371+
const common_chat_template tmpl(read_file(
372+
"models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
373373
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
374374

375375
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
376-
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
376+
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file(
377+
"models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
377378

378379
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
379380
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
@@ -384,7 +385,8 @@ static void test_template_output_parsers() {
384385
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
385386
}
386387
{
387-
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
388+
const common_chat_template tmpl(read_file(
389+
"models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
388390
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
389391

390392
assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));
@@ -395,7 +397,8 @@ static void test_template_output_parsers() {
395397
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
396398
}
397399
{
398-
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
400+
const common_chat_template tmpl(read_file(
401+
"models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
399402
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
400403

401404
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
@@ -406,7 +409,8 @@ static void test_template_output_parsers() {
406409
"<function=special_function>{\"arg1\": 1}</function>");
407410
}
408411
{
409-
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
412+
const common_chat_template tmpl(read_file(
413+
"models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
410414
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
411415

412416
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
@@ -420,7 +424,8 @@ static void test_template_output_parsers() {
420424
"{\"arg1\": 1}");
421425
}
422426
{
423-
const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
427+
const common_chat_template tmpl(read_file(
428+
"models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
424429
std::vector<std::string> end_tokens { "<|eot_id|>" };
425430

426431
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
@@ -431,7 +436,8 @@ static void test_template_output_parsers() {
431436
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
432437
}
433438
{
434-
const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
439+
const common_chat_template tmpl(read_file(
440+
"models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
435441
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
436442

437443
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));

0 commit comments

Comments
 (0)