Skip to content

Commit 90effb8

Browse files
author
ochafik
committed
Pass grammar laziness all the way down to sampler (need to print special trigger tokens e.g. for Nemo even w/ tool_choice=required)
1 parent ad22978 commit 90effb8

File tree

12 files changed

+62
-49
lines changed

12 files changed

+62
-49
lines changed

common/chat-handler.cpp

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
279279
}
280280
: tool_call;
281281

282+
data.grammar_lazy = false;
282283
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
283284
builder.add_schema("root", schema);
284285
}, grammar_options);
@@ -319,6 +320,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
319320
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
320321
fprintf(stderr, "[%s]\n", __func__);
321322
common_chat_data data;
323+
data.grammar_lazy = params.tool_choice != "required";
322324
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
323325
auto schemas = json::array();
324326
foreach_function(params.tools, [&](const json & tool) {
@@ -352,9 +354,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
352354
}
353355
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
354356
}, grammar_options);
355-
if (params.tool_choice != "required") {
356-
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
357-
}
357+
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
358358
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
359359
data.format = "mistral nemo tool calls";
360360
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
@@ -369,6 +369,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
369369
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
370370
common_chat_data data;
371371

372+
data.grammar_lazy = params.tool_choice != "required";
372373
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
373374
std::vector<std::string> tool_rules;
374375

@@ -385,14 +386,10 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
385386
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
386387
builder.add_schema(name + "-args", parameters) +
387388
" \"}\""));
388-
if (params.tool_choice != "required") {
389-
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
390-
}
389+
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
391390
});
392391
tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*"));
393-
if (params.tool_choice != "required") {
394-
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
395-
}
392+
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
396393
builder.add_rule("root", string_join(tool_rules, " | "));
397394
}, grammar_options);
398395
data.additional_stops.push_back("<|eom_id|>");
@@ -429,6 +426,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
429426
fprintf(stderr, "[%s]\n", __func__);
430427
common_chat_data data;
431428

429+
data.grammar_lazy = params.tool_choice != "required";
432430
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
433431
std::vector<std::string> tool_rules;
434432

@@ -446,9 +444,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
446444
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
447445
builder.add_schema(name + "-args", parameters) +
448446
" \"}\""));
449-
if (params.tool_choice != "required") {
450-
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
451-
}
447+
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
452448
});
453449

454450
builder.add_rule("root", string_join(tool_rules, " | "));
@@ -468,8 +464,7 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
468464
static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
469465
fprintf(stderr, "[%s]\n", __func__);
470466
common_chat_data data;
471-
data.grammar = "root ::= .*";
472-
// data.grammar = "root ::= .*";
467+
data.grammar_lazy = params.tool_choice != "required";
473468
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
474469
std::vector<std::string> tool_rules;
475470
foreach_function(params.tools, [&](const json & tool) {
@@ -480,9 +475,7 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
480475
tool_rules.push_back(builder.add_rule(name + "-call",
481476
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
482477
});
483-
if (params.tool_choice != "required") {
484-
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
485-
}
478+
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
486479
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space");
487480
}, grammar_options);
488481
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
@@ -499,6 +492,7 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
499492
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
500493
fprintf(stderr, "[%s]\n", __func__);
501494
common_chat_data data;
495+
data.grammar_lazy = params.tool_choice != "required";
502496
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
503497
auto schemas = json::array();
504498
foreach_function(params.tools, [&](const json & tool) {
@@ -525,9 +519,7 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
525519
}
526520
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
527521
}, grammar_options);
528-
if (params.tool_choice != "required") {
529-
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
530-
}
522+
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
531523
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
532524
data.format = "firefunction v2 tool calls";
533525
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
@@ -542,6 +534,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
542534
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
543535
common_chat_data data;
544536

537+
data.grammar_lazy = params.tool_choice != "required";
545538
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
546539
std::vector<std::string> first_tool_rules;
547540
std::vector<std::string> subsequent_tool_rules;
@@ -552,10 +545,8 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
552545
auto args_rule = builder.add_schema(name + "-args", parameters);
553546
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
554547
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
555-
if (params.tool_choice != "required") {
556-
data.grammar_triggers.push_back({name, /* .at_start = */ true});
557-
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
558-
}
548+
data.grammar_triggers.push_back({name, /* .at_start = */ true});
549+
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
559550
});
560551
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
561552
if (params.parallel_tool_calls) {
@@ -591,6 +582,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
591582
std::string python_code_argument_name;
592583
auto has_raw_python = false;
593584

585+
data.grammar_lazy = params.tool_choice != "required";
594586
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
595587
std::vector<std::string> tool_rules;
596588
foreach_function(params.tools, [&](const json & tool) {
@@ -624,15 +616,11 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
624616
});
625617
if (has_raw_python) {
626618
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
627-
if (params.tool_choice != "required") {
628-
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
629-
}
619+
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
630620
}
631621
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
632622
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
633-
if (params.tool_choice != "required") {
634-
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
635-
}
623+
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
636624
}, grammar_options);
637625

638626
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
@@ -666,6 +654,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
666654
fprintf(stderr, "[%s]\n", __func__);
667655
common_chat_data data;
668656
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
657+
data.grammar_lazy = params.tool_choice != "required";
669658
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
670659
std::vector<std::string> tool_rules;
671660
foreach_function(params.tools, [&](const json & tool) {
@@ -684,9 +673,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
684673
});
685674
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
686675
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
687-
if (params.tool_choice != "required") {
688-
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
689-
}
676+
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
690677
}, grammar_options);
691678

692679
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
@@ -701,7 +688,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
701688
std::sregex_iterator rend;
702689
std::sregex_iterator rit(input.begin(), end, start_pattern);
703690
if (rit == rend) {
704-
return {"assistant", input, {}};
691+
return {
692+
/* .role = */ "assistant",
693+
/* .content = */ input,
694+
/* .tool_calls = */ {},
695+
};
705696
}
706697

707698
common_chat_msg result;
@@ -732,7 +723,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
732723
}
733724
return result;
734725
} catch (const std::exception & e) {
735-
return {"assistant", input, {}};
726+
return {
727+
/* .role = */ "assistant",
728+
/* .content = */ input,
729+
/* .tool_calls = */ {},
730+
};
736731
}
737732
});
738733
return data;
@@ -744,6 +739,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
744739
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
745740
data.format = "content-only";
746741
data.parser = std::make_unique<text_chat_parser>();
742+
data.grammar_lazy = false;
747743
if (!params.json_schema.is_null()) {
748744
if (!params.grammar.empty()) {
749745
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");

common/chat-handler.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ struct common_chat_data {
4242
std::vector<std::string> additional_stops;
4343
std::unique_ptr<class common_chat_parser> parser;
4444
std::string format; // For debugging and testing.
45+
bool grammar_lazy = false;
4546
};
4647

4748
struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);

common/common.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,9 @@ struct common_params_sampling {
160160
};
161161

162162
std::string grammar; // optional BNF-like grammar to constrain sampling
163-
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to enable grammar
164-
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar
163+
bool grammar_lazy;
164+
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
165+
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
165166

166167
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
167168

common/sampling.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
159159
auto * result = new common_sampler {
160160
/* .params = */ params,
161161
/* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root",
162+
params.grammar_lazy,
162163
trigger_words.data(), trigger_words.size(),
163164
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()),
164165
/* .chain = */ llama_sampler_chain_init(lparams),

examples/gbnf-validator/gbnf-validator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ int main(int argc, char** argv) {
7676
grammar_str = buffer.str();
7777
}
7878

79-
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0);
79+
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
8080
if (grammar == nullptr) {
8181
fprintf(stdout, "Failed to initialize llama_grammar\n");
8282
return 1;

examples/server/server.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3816,6 +3816,7 @@ int main(int argc, char ** argv) {
38163816
task.params.oaicompat = oaicompat;
38173817
task.params.oaicompat_cmpl_id = completion_id;
38183818
task.params.sampling.grammar = chat_data.grammar;
3819+
task.params.sampling.grammar_lazy = chat_data.grammar_lazy;
38193820
for (const auto & trigger : chat_data.grammar_triggers) {
38203821
auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
38213822
if (ids.size() == 1) {
@@ -3830,6 +3831,9 @@ int main(int argc, char ** argv) {
38303831
if (chat_data.parser) {
38313832
task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
38323833
}
3834+
if (task.params.sampling.grammar_lazy) {
3835+
GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
3836+
}
38333837
// oaicompat_model is already populated by params_from_json_cmpl
38343838

38353839
tasks.push_back(task);

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,7 @@ extern "C" {
11981198
const struct llama_vocab * vocab,
11991199
const char * grammar_str,
12001200
const char * grammar_root,
1201+
bool lazy,
12011202
const char ** trigger_words,
12021203
size_t num_trigger_words,
12031204
const llama_token * trigger_tokens,

src/llama-grammar.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,8 @@ struct llama_grammar * llama_grammar_init_impl(
964964
vocab,
965965
std::move(vec_rules),
966966
std::move(stacks),
967-
/* .partial_utf8 = */ {},
967+
/* .partial_utf8 = */ {},
968+
/* .lazy =*/ false,
968969
/* .awaiting_trigger = */ false,
969970
/* .trigger_buffer = */ "",
970971
/* .trigger_tokens = */ {},
@@ -976,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl(
976977
const struct llama_vocab * vocab,
977978
const char * grammar_str,
978979
const char * grammar_root,
980+
bool lazy,
979981
const char ** trigger_words,
980982
size_t num_trigger_words,
981983
const llama_token * trigger_tokens,
@@ -1069,8 +1071,9 @@ struct llama_grammar * llama_grammar_init_impl(
10691071
vocab,
10701072
std::move(vec_rules),
10711073
std::move(stacks),
1072-
/* .partial_utf8 = */ {},
1073-
/* .awaiting_trigger = */ vec_trigger_tokens.size() > 0 || vec_trigger_words.size() > 0,
1074+
/* .partial_utf8 = */ {},
1075+
/* .lazy = */ lazy,
1076+
/* .awaiting_trigger = */ lazy,
10741077
/* .trigger_buffer = */ "",
10751078
std::move(vec_trigger_tokens),
10761079
std::move(vec_trigger_words),
@@ -1091,6 +1094,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
10911094
grammar.rules,
10921095
grammar.stacks,
10931096
grammar.partial_utf8,
1097+
grammar.lazy,
10941098
grammar.awaiting_trigger,
10951099
grammar.trigger_buffer,
10961100
grammar.trigger_tokens,

src/llama-grammar.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,12 @@ struct llama_grammar {
116116
llama_partial_utf8 partial_utf8;
117117

118118
// lazy grammars wait for trigger words or tokens before constraining the sampling.
119-
bool awaiting_trigger;
120-
std::string trigger_buffer;
121-
std::vector<llama_token> trigger_tokens;
119+
// we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
120+
// (useful e.g. for tool_choice=required)
121+
bool lazy; // Useful when resetting
122+
bool awaiting_trigger; // Initialized to lazy
123+
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
124+
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
122125
std::vector<std::string> trigger_words;
123126
};
124127

@@ -137,6 +140,7 @@ struct llama_grammar * llama_grammar_init_impl(
137140
const struct llama_vocab * vocab,
138141
const char * grammar_str,
139142
const char * grammar_root,
143+
bool lazy,
140144
const char ** trigger_words,
141145
size_t num_trigger_words,
142146
const llama_token * trigger_tokens,

0 commit comments

Comments
 (0)