Commit 43385b2

Author: ochafik
Commit message: sync: minja
1 parent 1159455 commit 43385b2

File tree: 3 files changed, +126 −103 lines

3 files changed

+126
-103
lines changed

common/chat-template.hpp

Lines changed: 2 additions & 2 deletions
@@ -61,7 +61,7 @@ class chat_template {
         });
         supports_tools_ = source.find("tools") != std::string::npos;

-        requires_object_arguments_ =
+        requires_object_arguments_ =
             try_raw_render({
                 {
                     {"role", "user"},
@@ -298,7 +298,7 @@ class chat_template {
         if (!tools.is_null()) {
             auto tools_val = minja::Value(actual_tools);
             context->set("tools", tools_val);
-            if (has_code_interpreter) {
+            if (has_code_interpreter && !extra_context.contains("builtin_tools")) {
                 auto builtin_tools_val = minja::Value(json {"code_interpreter"});
                 context->set("builtin_tools", builtin_tools_val);
             }
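
The second hunk fixes an override bug: `code_interpreter` is now injected as a default `builtin_tools` entry only when the caller has not already supplied `builtin_tools` through `extra_context`. A minimal standalone sketch of that guard pattern, using nlohmann::json and a hypothetical helper name (`apply_builtin_tools` is not part of the real chat_template class):

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Standalone sketch: inject a default "builtin_tools" entry only when the
// caller has not already provided one (the pre-commit code clobbered it).
static void apply_builtin_tools(json & ctx, const json & extra_context, bool has_code_interpreter) {
    for (const auto & el : extra_context.items()) {
        ctx[el.key()] = el.value();  // caller-supplied values always win
    }
    if (has_code_interpreter && !extra_context.contains("builtin_tools")) {
        ctx["builtin_tools"] = json {"code_interpreter"};
    }
}

int main() {
    json ctx1, ctx2;
    apply_builtin_tools(ctx1, json::object(), /* has_code_interpreter= */ true);
    apply_builtin_tools(ctx2, json {{"builtin_tools", {"browser"}}}, /* has_code_interpreter= */ true);
    std::cout << ctx1.dump() << "\n";  // {"builtin_tools":["code_interpreter"]}
    std::cout << ctx2.dump() << "\n";  // {"builtin_tools":["browser"]}
}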

common/minja.hpp

Lines changed: 60 additions & 53 deletions
@@ -2648,31 +2648,34 @@ inline std::shared_ptr<Context> Context::builtins() {
                 return filter.call(context, actual_args);
             });
         };
-        // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject
-        globals.set("reject", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-            args.expectArgs("reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
-            auto & items = args.args[0];
-            auto filter_fn = context->get(args.args[1]);
-            if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
+        auto select_or_reject = [make_filter](bool is_select) {
+            return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+                args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+                auto & items = args.args[0];
+                auto filter_fn = context->get(args.args[1]);
+                if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());

-            auto filter_args = Value::array();
-            for (size_t i = 2, n = args.args.size(); i < n; i++) {
-                filter_args.push_back(args.args[i]);
-            }
-            auto filter = make_filter(filter_fn, filter_args);
+                auto filter_args = Value::array();
+                for (size_t i = 2, n = args.args.size(); i < n; i++) {
+                    filter_args.push_back(args.args[i]);
+                }
+                auto filter = make_filter(filter_fn, filter_args);

-            auto res = Value::array();
-            for (size_t i = 0, n = items.size(); i < n; i++) {
-                auto & item = items.at(i);
-                ArgumentsValue filter_args;
-                filter_args.args.emplace_back(item);
-                auto pred_res = filter.call(context, filter_args);
-                if (!pred_res.to_bool()) {
-                    res.push_back(item);
+                auto res = Value::array();
+                for (size_t i = 0, n = items.size(); i < n; i++) {
+                    auto & item = items.at(i);
+                    ArgumentsValue filter_args;
+                    filter_args.args.emplace_back(item);
+                    auto pred_res = filter.call(context, filter_args);
+                    if (pred_res.to_bool() == (is_select ? true : false)) {
+                        res.push_back(item);
+                    }
                 }
-            }
-            return res;
-        }));
+                return res;
+            });
+        };
+        globals.set("select", select_or_reject(/* is_select= */ true));
+        globals.set("reject", select_or_reject(/* is_select= */ false));
         globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
             auto res = Value::array();
             if (args.args.size() == 1 &&
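
This hunk and the `selectattr` one below apply the same refactor: rather than maintaining two near-identical callables, a factory lambda is parameterized on `is_select`, and the predicate result is compared against that flag, so `select` keeps matching items while `reject` keeps the rest. A self-contained sketch of the pattern, with `std::function` standing in for minja's `Value::callable`:

#include <functional>
#include <iostream>
#include <vector>

// Factory parameterized on predicate polarity, mirroring the
// select_or_reject lambda introduced in this commit.
static std::function<std::vector<int>(const std::vector<int> &, const std::function<bool(int)> &)>
select_or_reject(bool is_select) {
    return [is_select](const std::vector<int> & items, const std::function<bool(int)> & pred) {
        std::vector<int> res;
        for (int item : items) {
            // select keeps items where pred is true; reject keeps the rest.
            if (pred(item) == is_select) {
                res.push_back(item);
            }
        }
        return res;
    };
}

int main() {
    auto select = select_or_reject(/* is_select= */ true);
    auto reject = select_or_reject(/* is_select= */ false);
    auto is_odd = [](int x) { return x % 2 != 0; };
    for (int x : select({1, 2, 3, 4}, is_odd)) std::cout << x << ' ';  // 1 3
    std::cout << '\n';
    for (int x : reject({1, 2, 3, 4}, is_odd)) std::cout << x << ' ';  // 2 4
    std::cout << '\n';
}

Parameterizing on the polarity keeps the two filters provably in sync: any fix to the shared body applies to both.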
@@ -2720,41 +2723,45 @@
             if (!text.empty() && text.back() == '\n') out += "\n";
             return out;
         }));
-        globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-            args.expectArgs("selectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
-            auto & items = args.args[0];
-            if (items.is_null())
-                return Value::array();
-            auto attr_name = args.args[1].get<std::string>();
-
-            bool has_test = false;
-            Value test_fn;
-            ArgumentsValue test_args {{Value()}, {}};
-            if (args.args.size() >= 3) {
-                has_test = true;
-                test_fn = context->get(args.args[2]);
-                if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
-                for (size_t i = 3, n = args.args.size(); i < n; i++) {
-                    test_args.args.emplace_back(args.args[i]);
+        auto select_or_reject_attr = [](bool is_select) {
+            return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+                args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+                auto & items = args.args[0];
+                if (items.is_null())
+                    return Value::array();
+                auto attr_name = args.args[1].get<std::string>();
+
+                bool has_test = false;
+                Value test_fn;
+                ArgumentsValue test_args {{Value()}, {}};
+                if (args.args.size() >= 3) {
+                    has_test = true;
+                    test_fn = context->get(args.args[2]);
+                    if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
+                    for (size_t i = 3, n = args.args.size(); i < n; i++) {
+                        test_args.args.emplace_back(args.args[i]);
+                    }
+                    test_args.kwargs = args.kwargs;
                 }
-                test_args.kwargs = args.kwargs;
-            }

-            auto res = Value::array();
-            for (size_t i = 0, n = items.size(); i < n; i++) {
-                auto & item = items.at(i);
-                auto attr = item.get(attr_name);
-                if (has_test) {
-                    test_args.args[0] = attr;
-                    if (test_fn.call(context, test_args).to_bool()) {
-                        res.push_back(item);
+                auto res = Value::array();
+                for (size_t i = 0, n = items.size(); i < n; i++) {
+                    auto & item = items.at(i);
+                    auto attr = item.get(attr_name);
+                    if (has_test) {
+                        test_args.args[0] = attr;
+                        if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
+                            res.push_back(item);
+                        }
+                    } else {
+                        res.push_back(attr);
                     }
-                } else {
-                    res.push_back(attr);
                 }
-            }
-            return res;
-        }));
+                return res;
+            });
+        };
+        globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
+        globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
         globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
             std::vector<int64_t> startEndStep(3);
             std::vector<bool> param_set(3);
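
`selectattr`/`rejectattr` do the same thing one level deeper: the test is applied to a named attribute of each item rather than to the item itself. A plain-C++ sketch of what the common `messages | selectattr("role", "equalto", "user")` case computes (the real implementation resolves the test function, e.g. "equalto", through the minja context; `filter_by_attr` is a hypothetical name):

#include <iostream>
#include <nlohmann/json.hpp>
#include <string>

using json = nlohmann::json;

// Keep (is_select == true) or drop (is_select == false) items whose
// named attribute equals the expected value.
static json filter_by_attr(const json & items, const std::string & attr, const json & expected, bool is_select) {
    json res = json::array();
    for (const auto & item : items) {
        bool match = item.contains(attr) && item.at(attr) == expected;
        if (match == is_select) res.push_back(item);
    }
    return res;
}

int main() {
    json messages = json::array({
        {{"role", "user"},      {"content", "hi"}},
        {{"role", "assistant"}, {"content", "hello"}},
    });
    std::cout << filter_by_attr(messages, "role", "user", true).dump() << "\n";   // user messages
    std::cout << filter_by_attr(messages, "role", "user", false).dump() << "\n";  // everything else
}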

examples/server/server.cpp

Lines changed: 64 additions & 48 deletions
@@ -211,7 +211,6 @@ struct server_task {
     static slot_params params_from_json_cmpl(
             const llama_context * ctx,
             const common_params & params_base,
-            const common_chat_template * tmpl,
            const json & data) {
        const llama_model * model = llama_get_model(ctx);
        const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -330,30 +329,19 @@ struct server_task {
             }
         }

-        if (tmpl && params_base.use_jinja) {
-            common_chat_params chat_params;
-            chat_params.messages = json_value(data, "messages", json::array());
-            chat_params.tools = json_value(data, "tools", json());
-            chat_params.tool_choice = json_value(data, "tool_choice", std::string("auto"));
-            chat_params.json_schema = json_value(data, "json_schema", json());
-            chat_params.parallel_tool_calls = json_value(data, "parallel_tool_calls", false);
-            chat_params.stream = json_value(data, "stream", false);
-
-            auto chat_data = common_chat_init(*tmpl, chat_params);
-            params.chat_parser = std::move(chat_data.handler);
-            params.sampling.grammar = chat_data.grammar;
-            for (const auto & stop : chat_data.additional_stops) {
-                params.antiprompt.push_back(stop);
+        if (!params_base.use_jinja) {
+            if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
+                throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
             }
-            for (const auto & trigger : chat_data.grammar_triggers) {
-                auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
-                if (ids.size() == 1) {
-                    LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
-                    params.sampling.grammar_trigger_tokens.push_back(ids[0]);
-                    continue;
+            if (data.contains("json_schema") && !data.contains("grammar")) {
+                try {
+                    auto schema = json_value(data, "json_schema", json::object());
+                    params.sampling.grammar = json_schema_to_grammar(schema);
+                } catch (const std::exception & e) {
+                    throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
                 }
-                LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
-                params.sampling.grammar_trigger_words.push_back(trigger);
+            } else {
+                params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
             }
         }

@@ -363,15 +351,13 @@ struct server_task {
             }
             if (data.contains("json_schema") && !data.contains("grammar")) {
                 try {
-                    auto schema = json_value(data, "json_schema", json::object());
-                    params.sampling.grammar = json_schema_to_grammar(schema);
+                    params.sampling.grammar = json_schema_to_grammar(json_value(data, "json_schema", json::object()));
                 } catch (const std::exception & e) {
                     throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
                 }
             } else {
                 params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
             }
-            LOG_INF("Grammar: %s\n", params.sampling.grammar.c_str());

             {
                 params.sampling.logit_bias.clear();
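
Both of these hunks tighten grammar handling for the non-Jinja path: "json_schema" and "grammar" are mutually exclusive, and when only a schema is given it is compiled into a grammar. A condensed sketch of that resolution rule (`json_schema_to_grammar_stub` and `resolve_grammar` are hypothetical names; the stub body is a placeholder, not the real converter):

#include <nlohmann/json.hpp>
#include <stdexcept>
#include <string>

using json = nlohmann::json;

// Placeholder for llama.cpp's json_schema_to_grammar(), which compiles
// a JSON schema into a GBNF grammar string.
static std::string json_schema_to_grammar_stub(const json & schema) {
    return "# grammar derived from: " + schema.dump();
}

// Sketch of the rule these hunks enforce: "json_schema" and "grammar"
// are mutually exclusive, and a lone schema is compiled into a grammar.
static std::string resolve_grammar(const json & data, const std::string & default_grammar) {
    const bool has_schema  = data.contains("json_schema") && !data.at("json_schema").is_null();
    const bool has_grammar = data.contains("grammar")     && !data.at("grammar").is_null();
    if (has_schema && has_grammar) {
        throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
    }
    if (data.contains("json_schema") && !data.contains("grammar")) {
        return json_schema_to_grammar_stub(data.at("json_schema"));
    }
    return data.value("grammar", default_grammar);
}

int main() {
    json req = {{"json_schema", {{"type", "object"}}}};
    return resolve_grammar(req, "").empty() ? 1 : 0;  // exercises the schema path
}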
@@ -2248,9 +2234,15 @@ struct server_context {
     }

     void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
-        auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send);
-        if (!opt_msg) {
-            return;
+        common_chat_msg msg;
+        if (slot.params.chat_parser) {
+            if (auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send)) {
+                msg = *opt_msg;
+            } else {
+                return;
+            }
+        } else {
+            msg.content = tkn.text_to_send;
         }
         auto res = std::make_unique<server_task_result_cmpl_partial>();

@@ -2267,7 +2259,7 @@
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = *opt_msg;
+        res->oaicompat_chat_msg = msg;

         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -2308,7 +2300,11 @@
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
+        res->oaicompat_chat_msg = slot.params.chat_parser ? slot.params.chat_parser->parse_final(slot.generated_text) : common_chat_msg {
+            /* .role = */ "assistant",
+            /* .content = */ slot.generated_text,
+            /* .tool_calls = */ {}
+        };

         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -3773,7 +3769,7 @@ int main(int argc, char ** argv) {
             std::function<bool()> is_connection_closed,
             httplib::Response & res,
             oaicompat_type oaicompat,
-            const common_chat_template * tmpl) {
+            const common_chat_template * tmpl = nullptr) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);

         if (ctx_server.params_base.embedding) {
@@ -3785,21 +3781,29 @@ int main(int argc, char ** argv) {
         std::vector<server_task> tasks;

         try {
-            fprintf(stderr, "PROMPT: %s\n", data.at("prompt").get<std::string>().c_str());
-            std::string prompt;
+            common_chat_data chat_data;
             if (tmpl && ctx_server.params_base.use_jinja) {
-                auto chat_data = common_chat_init(*tmpl, {
-                    /* .messages = */ json_data(data, "messages", json::array()),
-                    /* .tools = */ json_data(data, "tools", json()),
-                    /
+                chat_data = common_chat_init(*tmpl, {
+                    /* .messages = */ json_value(data, "messages", json::array()),
+                    /* .tools = */ json_value(data, "tools", json()),
+                    /* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
+                    /* .json_schema = */ json_value(data, "json_schema", json()),
+                    /* .parallel_tool_calls = */ json_value(data, "json_schema", true),
+                    /* .stream = */ json_value(data, "json_schema", false),
+                    /* .grammar = */ json_value(data, "grammar", std::string("")),
                 });
-
-                prompt = ctx_server.chat_templates.template_default->render(data.at("prompt").get<std::string>());
+                if (data.contains("grammar")) {
+                    chat_data.grammar = data.at("grammar");
+                }
             } else {
-                prompt = data.at("prompt").get<std::string>();
+                chat_data.prompt = data.at("prompt");
+                if (data.contains("grammar")) {
+                    chat_data.grammar = data.at("grammar");
+                } else if (data.contains("json_schema")) {
+                    chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
+                }
             }
-            task.params.chat_parser = common_chat_init()
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
+            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(type);
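
The old WIP code rendered the prompt ad hoc (and left a dangling `task.params.chat_parser = common_chat_init()` line); the new code funnels everything through a single `common_chat_data` value. Its definition is not part of this commit, but the fields the diff reads from it suggest roughly this shape (a hedged sketch only; names below are taken from the accesses above, types are inferred guesses, and the `_sketch` suffix marks everything as hypothetical):

#include <memory>
#include <nlohmann/json.hpp>
#include <string>
#include <vector>

using json = nlohmann::json;

// Inferred from chat_data.grammar_triggers[i].word in the hunk below.
struct common_chat_grammar_trigger_sketch {
    std::string word;  // trigger text; becomes a token trigger if it tokenizes to 1 token
};

// Inferred from chat_data.parser->clone() in the hunk below.
struct common_chat_parser_sketch {
    virtual ~common_chat_parser_sketch() = default;
    virtual std::unique_ptr<common_chat_parser_sketch> clone() const = 0;
};

// Inferred aggregate: prompt, grammar, triggers, stops, and parser all
// flow from template initialization into each server_task.
struct common_chat_data_sketch {
    json prompt;                                                      // rendered (Jinja) or raw "prompt"
    std::string grammar;                                              // GBNF, possibly from "json_schema"
    std::vector<common_chat_grammar_trigger_sketch> grammar_triggers;
    std::vector<std::string> additional_stops;                        // merged into task antiprompts
    std::unique_ptr<common_chat_parser_sketch> parser;                // cloned for all but the last prompt
};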
@@ -3811,16 +3815,27 @@ int main(int argc, char ** argv) {
                 task.params = server_task::params_from_json_cmpl(
                     ctx_server.ctx,
                     ctx_server.params_base,
-                    nullptr,
                     data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);

                 // OAI-compat
                 task.params.oaicompat = oaicompat;
                 task.params.oaicompat_cmpl_id = completion_id;
-                task.params.chat_parser = common_chat_init()
-                task.params.oaicompat_tools = json_value(data, "tools", json());
-                task.params.oaicompat_tool_call_style = tool_call_style;
+                task.params.sampling.grammar = chat_data.grammar;
+                for (const auto & trigger : chat_data.grammar_triggers) {
+                    auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
+                    if (ids.size() == 1) {
+                        LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
+                        task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                        continue;
+                    }
+                    LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
+                    task.params.sampling.grammar_trigger_words.push_back(trigger);
+                }
+                task.params.antiprompt = chat_data.additional_stops;
+                if (chat_data.parser) {
+                    task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
+                }
                 // oaicompat_model is already populated by params_from_json_cmpl

                 tasks.push_back(task);
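
Trigger words that tokenize to exactly one token are registered as token-level triggers, which can be matched cheaply by token id during sampling; anything longer stays a word-level trigger matched against decoded text. A toy demonstration of that classification (the character-level `tokenize` stub stands in for `common_tokenize`):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

using llama_token = int32_t;

// Toy stand-in for common_tokenize(): one token per character, so only
// single-character words become single-token triggers in this sketch.
static std::vector<llama_token> tokenize(const std::string & word) {
    return std::vector<llama_token>(word.begin(), word.end());
}

int main() {
    std::vector<llama_token> trigger_tokens;  // matched per sampled token id
    std::vector<std::string> trigger_words;   // matched against decoded text
    for (const std::string & word : {"<", "<tool_call>"}) {
        auto ids = tokenize(word);
        if (ids.size() == 1) {
            trigger_tokens.push_back(ids[0]);
            continue;
        }
        trigger_words.push_back(word);
    }
    std::cout << trigger_tokens.size() << " token trigger(s), "
              << trigger_words.size() << " word trigger(s)\n";  // 1 token trigger(s), 1 word trigger(s)
}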
@@ -4005,7 +4020,8 @@
                 data,
                 req.is_connection_closed,
                 res,
-                OAICOMPAT_TYPE_CHAT);
+                OAICOMPAT_TYPE_CHAT,
+                &chat_template);
         };

         const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
