
Commit 923c805

Olivier Chafik committed
rm dead code + nits
1 parent 18d5a1b commit 923c805

7 files changed, +25 -35 lines changed


examples/server/server.cpp

Lines changed: 2 additions & 4 deletions
@@ -768,7 +768,6 @@ struct server_task_result_cmpl_partial : server_task_result {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    std::shared_ptr<common_chat_parser> chat_parser;

     virtual int get_index() override {
         return index;
@@ -1191,7 +1190,6 @@ struct server_slot {

     std::string stopping_word;

-    std::shared_ptr<common_chat_parser> chat_parser;

     // sampling
     json json_schema;
@@ -1200,6 +1198,8 @@ struct server_slot {

     llama_token sampled;

+    common_chat_parser chat_parser;
+
     // stats
     size_t n_sent_text = 0; // number of sent text character

@@ -3998,8 +3998,6 @@ int main(int argc, char ** argv) {

         auto body = json::parse(req.body);
         const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
-        LOG_INF("Request: %s\n", body.dump(2).c_str());
-
         json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);

         return handle_completions_impl(

examples/server/tests/unit/test_tool_call.py

Lines changed: 3 additions & 3 deletions
@@ -154,7 +154,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
     if template_override:
         (template_hf_repo, template_variant) = template_override
         server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
-        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "max_tokens": n_predict,
@@ -243,7 +243,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
     if template_override:
         (template_hf_repo, template_variant) = template_override
         server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
-        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     server.start(timeout_seconds=15*60)
     res = server.make_request("POST", "/chat/completions", data={
         "max_tokens": 256,
@@ -292,7 +292,7 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
     if template_override:
         (template_hf_repo, template_variant) = template_override
         server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
-        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     server.start(timeout_seconds=15*60)
     res = server.make_request("POST", "/chat/completions", data={
         "max_tokens": 256,

examples/server/utils.hpp

Lines changed: 5 additions & 11 deletions
@@ -596,6 +596,11 @@ static json oaicompat_completion_params_parse(
             throw std::runtime_error("tools param requires --jinja flag");
         }
     }
+    if (!use_jinja) {
+        if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
+            throw std::runtime_error("Unsupported param: tool_choice");
+        }
+    }

     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
@@ -605,7 +610,6 @@ static json oaicompat_completion_params_parse(
     }

     // Handle "response_format" field
-    auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
     if (body.contains("response_format")) {
         json response_format = json_value(body, "response_format", json::object());
         std::string response_type = json_value(response_format, "type", std::string());
@@ -649,16 +653,6 @@ static json oaicompat_completion_params_parse(
         throw std::runtime_error("top_logprobs requires logprobs to be set to true");
     }

-    // Params supported by OAI but unsupported by llama.cpp
-    if (!use_jinja) {
-        static const std::vector<std::string> unsupported_params { "tool_choice" };
-        for (const auto & param : unsupported_params) {
-            if (body.contains(param)) {
-                throw std::runtime_error("Unsupported param: " + param);
-            }
-        }
-    }
-
     // Copy remaining properties to llama_params
     // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
     // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
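For illustration, here is a minimal, self-contained sketch of the consolidated tool_choice guard above. It is not the server code itself: the standalone check_tool_choice helper and the sample request bodies are made up for this example, and it simply assumes nlohmann::json is available.

// Hypothetical standalone version of the guard shown in the diff above:
// with --jinja disabled, a non-null "tool_choice" in the request body is rejected up front.
#include <nlohmann/json.hpp>
#include <iostream>
#include <stdexcept>

using json = nlohmann::json;

static void check_tool_choice(const json & body, bool use_jinja) {
    if (!use_jinja) {
        if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
            throw std::runtime_error("Unsupported param: tool_choice");
        }
    }
}

int main() {
    json ok  = {{"messages", json::array()}, {"tool_choice", nullptr}};    // null tool_choice is tolerated
    json bad = {{"messages", json::array()}, {"tool_choice", "required"}}; // rejected without --jinja

    check_tool_choice(ok, /*use_jinja=*/false);
    try {
        check_tool_choice(bad, /*use_jinja=*/false);
    } catch (const std::exception & e) {
        std::cout << e.what() << std::endl; // prints: Unsupported param: tool_choice
    }
    return 0;
}

One behavioral nuance visible in the diff: the new check only fires for a non-null value, whereas the removed loop rejected the "tool_choice" key regardless of its value.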

scripts/get_hf_chat_template.py renamed to scripts/get_chat_template.py

Lines changed: 7 additions & 6 deletions
@@ -4,20 +4,20 @@
 If a model has multiple chat templates, you can specify the variant name.

 Syntax:
-    ./scripts/get_hf_chat_template.py model_id [variant]
+    ./scripts/get_chat_template.py model_id [variant]

 Examples:
-    ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
-    ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
-    ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
+    ./scripts/get_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
+    ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
+    ./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct
 '''

 import json
 import re
 import sys


-def get_hf_chat_template(model_id, variant=None):
+def get_chat_template(model_id, variant=None):
     try:
         # Use huggingface_hub library if available.
         # Allows access to gated models if the user has access and ran `huggingface-cli login`.
@@ -69,9 +69,10 @@ def main(args):
     model_id = args[0]
     variant = None if len(args) < 2 else args[1]

-    template = get_hf_chat_template(model_id, variant)
+    template = get_chat_template(model_id, variant)
     sys.stdout.write(template)


 if __name__ == '__main__':
     main(sys.argv[1:])
+

src/llama-grammar.cpp

Lines changed: 1 addition & 1 deletion
@@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) {
             }
         }
     } catch (const std::exception & err) {
-        fprintf(stderr, "\n%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
+        fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
         rules.clear();
         return false;
     }

src/llama-grammar.h

Lines changed: 2 additions & 2 deletions
@@ -118,8 +118,8 @@ struct llama_grammar {
     // lazy grammars wait for trigger words or tokens before constraining the sampling.
     // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
     // (useful e.g. for tool_choice=required)
-    bool lazy; // Useful when resetting
-    bool awaiting_trigger; // Initialized to lazy
+    bool lazy;
+    bool awaiting_trigger; // Initialized to true for lazy grammars only
     std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
     std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
     std::vector<std::string> trigger_words;
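To make the comments above concrete, here is a hedged sketch of the lazy-grammar flow they describe: output is buffered in trigger_buffer while awaiting_trigger is set, and constraining only begins once a trigger word is seen. The field names mirror the struct above, but the accept_text logic and the trigger word are illustrative, not llama.cpp's actual sampler code.

#include <iostream>
#include <string>
#include <vector>

// Illustrative only: mirrors the lazy/awaiting_trigger/trigger_buffer fields above,
// not the real llama_grammar implementation.
struct lazy_grammar_sketch {
    bool lazy             = true;
    bool awaiting_trigger = true;               // initialized to true for lazy grammars only
    std::string trigger_buffer;                 // output buffered until a trigger is found
    std::vector<std::string> trigger_words = {"<tool_call>"}; // hypothetical trigger word

    void accept_text(const std::string & piece) {
        if (!awaiting_trigger) {
            // grammar already active: this is where text would be matched against the rules
            return;
        }
        trigger_buffer += piece;
        for (const auto & word : trigger_words) {
            if (trigger_buffer.find(word) != std::string::npos) {
                awaiting_trigger = false;       // trigger seen: start constraining from here on
                trigger_buffer.clear();         // buffer is cleared once the trigger is found
                return;
            }
        }
    }
};

int main() {
    lazy_grammar_sketch g;
    g.accept_text("Sure, let me call a tool: ");
    g.accept_text("<tool_call>");
    std::cout << (g.awaiting_trigger ? "still free-form" : "grammar active") << std::endl; // grammar active
    return 0;
}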

tests/test-chat-handler.cpp

Lines changed: 5 additions & 8 deletions
@@ -169,9 +169,6 @@ struct delta_data {
 };

 static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
-    fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
-    fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
-
     common_chat_params params;
     params.parallel_tool_calls = true;
     params.messages = json::array();
@@ -209,12 +206,14 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
     return {delta, full_data.grammar, full_data.parser};
 }

+/*
+  Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
+  gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
+  the parsed message is the same as the test_message
+*/
 static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", bool skip_grammar_test = false, bool skip_parser_test = false) {
-    // auto tool_call_style = common_tool_call_style_detect(tmpl);
     common_chat_msg expected_msg = msg_from_json(test_message);

-    // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
-    // get the diff and try and parse it w/ the grammar.
     auto user_message = json {
         {"role", "user"},
         {"content", "Hello, world!"}
@@ -228,7 +227,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
     params.tools = tools;

     auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
-    std::cout << "Full delta:\n```\n" << data.delta << "\n```" << std::endl;
     if (!expected_delta.empty()) {
         assert_equals(expected_delta, data.delta);
     }
@@ -495,7 +493,6 @@ static void test_template_output_parsers() {
 }

 int main() {
-    // test_parsing();
     test_template_output_parsers();

     std::cout << "\n[tool-call] All tests passed!" << std::endl;
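As a rough companion to the comment added above, this is a toy sketch of the delta round-trip idea: render the prompt with only the user message plus the generation prompt, render it again with the test message appended, take the differing suffix, and strip end tokens before it would be parsed back. The apply_toy_template function and its tags are invented for the example; the real test goes through the actual chat template and grammar machinery.

#include <cassert>
#include <string>
#include <vector>

struct toy_msg { std::string role, content; };

// Stand-in for applying a chat template (purely illustrative formatting).
static std::string apply_toy_template(const std::vector<toy_msg> & messages, bool add_generation_prompt) {
    std::string out;
    for (const auto & m : messages) {
        out += "<|" + m.role + "|>" + m.content + "<|end|>";
    }
    if (add_generation_prompt) {
        out += "<|assistant|>";
    }
    return out;
}

int main() {
    std::vector<toy_msg> user_only = {{"user", "Hello, world!"}};
    std::vector<toy_msg> with_test = {{"user", "Hello, world!"}, {"assistant", "Hi there"}};

    // Prompt with add_generation_prompt=true, then with the test message and add_generation_prompt=false.
    std::string prefix = apply_toy_template(user_only, /*add_generation_prompt=*/true);
    std::string full   = apply_toy_template(with_test, /*add_generation_prompt=*/false);

    // The delta is whatever the template emitted for the test message alone.
    assert(full.compare(0, prefix.size(), prefix) == 0);
    std::string delta = full.substr(prefix.size());

    // Remove end tokens; the remainder is what the parser would be checked against.
    for (const std::string & end_token : {std::string("<|end|>")}) {
        auto pos = delta.find(end_token);
        if (pos != std::string::npos) {
            delta = delta.substr(0, pos);
        }
    }
    assert(delta == "Hi there");
    return 0;
}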
