Skip to content

Commit b37779b

Browse files
author
ochafik
committed
sampler: turn lazy grammar trigger words into regexes
Update llama-grammar.h update Update llama-grammar.h Update common.h Update common.h Update sampling.cpp Update chat.cpp update test_tool_call.py Update server.cpp Update utils.hpp Update chat.cpp Update test_tool_call.py Update fetch_server_test_models.py
1 parent 63e489c commit b37779b

File tree

13 files changed

+697
-289
lines changed

13 files changed

+697
-289
lines changed

common/chat.cpp

Lines changed: 294 additions & 130 deletions
Large diffs are not rendered by default.

common/common.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,11 @@ void string_replace_all(std::string & s, const std::string & search, const std::
483483
s = std::move(builder);
484484
}
485485

486+
// Escape every ECMAScript regex metacharacter in `s` so that the returned
// string, when embedded in a std::regex pattern, matches the input literally.
std::string regex_escape(const std::string & s) {
    // Character class covering all regex special characters (incl. backslash).
    static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
    // "$&" is the standard ECMAScript reference to the entire match;
    // "$0" is not part of the replacement-format grammar and is only
    // honored by some implementations, so use the portable form.
    return std::regex_replace(s, special_chars, "\\$&");
}
490+
486491
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
487492
std::ostringstream result;
488493
for (size_t i = 0; i < values.size(); ++i) {

common/common.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,16 @@ enum common_conversation_mode {
110110
COMMON_CONVERSATION_MODE_AUTO = 2,
111111
};
112112

113+
enum common_grammar_trigger_type {
114+
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
115+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
116+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
117+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
118+
};
119+
113120
struct common_grammar_trigger {
114-
std::string word;
115-
bool at_start;
121+
common_grammar_trigger_type type;
122+
std::variant<llama_token, std::string> value;
116123
};
117124

118125
// sampling parameters
@@ -163,8 +170,7 @@ struct common_params_sampling {
163170

164171
std::string grammar; // optional BNF-like grammar to constrain sampling
165172
bool grammar_lazy = false;
166-
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
167-
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
173+
std::vector<common_grammar_trigger> grammar_triggers; // optional trigger words to trigger lazy grammar
168174
std::set<llama_token> preserved_tokens;
169175

170176
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@@ -453,6 +459,8 @@ std::string string_repeat(const std::string & str, size_t n);
453459

454460
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
455461

462+
std::string regex_escape(const std::string & s);
463+
456464
template<class T>
457465
static std::vector<T> string_split(const std::string & str, char delim) {
458466
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");

common/sampling.cpp

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,16 +159,53 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
159159
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
160160
#endif // LLAMA_USE_LLGUIDANCE
161161
} else {
162-
std::vector<const char *> trigger_words;
163-
trigger_words.reserve(params.grammar_trigger_words.size());
164-
for (const auto & str : params.grammar_trigger_words) {
165-
trigger_words.push_back(str.word.c_str());
162+
std::vector<std::string> patterns_at_start;
163+
std::vector<std::string> patterns_anywhere;
164+
std::vector<llama_token> trigger_tokens;
165+
for (const auto & trigger : params.grammar_triggers) {
166+
switch (trigger.type) {
167+
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
168+
{
169+
const auto & word = std::get<std::string>(trigger.value);
170+
patterns_anywhere.push_back(regex_escape(word));
171+
break;
172+
}
173+
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
174+
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
175+
{
176+
const auto & pattern = std::get<std::string>(trigger.value);
177+
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
178+
break;
179+
}
180+
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
181+
{
182+
const auto & token = std::get<llama_token>(trigger.value);
183+
trigger_tokens.push_back(token);
184+
break;
185+
}
186+
default:
187+
GGML_ASSERT(false && "unknown trigger type");
188+
}
189+
}
190+
191+
std::vector<std::string> trigger_patterns;
192+
if (!patterns_at_start.empty()) {
193+
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
194+
}
195+
if (!patterns_anywhere.empty()) {
196+
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
197+
}
198+
199+
std::vector<const char *> trigger_patterns_c;
200+
trigger_patterns_c.reserve(trigger_patterns.size());
201+
for (const auto & regex : trigger_patterns) {
202+
trigger_patterns_c.push_back(regex.c_str());
166203
}
167204

168205
grmr = params.grammar_lazy
169-
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
170-
trigger_words.data(), trigger_words.size(),
171-
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
206+
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
207+
trigger_patterns_c.data(), trigger_patterns_c.size(),
208+
trigger_tokens.data(), trigger_tokens.size())
172209
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
173210
}
174211

examples/server/server.cpp

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,22 @@ struct slot_params {
131131
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
132132
}
133133

134-
std::vector<std::string> grammar_trigger_words;
135-
for (const auto & trigger : sampling.grammar_trigger_words) {
136-
grammar_trigger_words.push_back(trigger.word);
134+
auto grammar_triggers = json::array();
135+
for (const auto & trigger : sampling.grammar_triggers) {
136+
switch (trigger.type) {
137+
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
138+
grammar_triggers.push_back({{"word", std::get<std::string>(trigger.value)}});
139+
break;
140+
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
141+
grammar_triggers.push_back({{"pattern", std::get<std::string>(trigger.value)}});
142+
break;
143+
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
144+
grammar_triggers.push_back({{"pattern_start", std::get<std::string>(trigger.value)}});
145+
break;
146+
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
147+
grammar_triggers.push_back({{"token", std::get<llama_token>(trigger.value)}});
148+
break;
149+
}
137150
}
138151

139152
return json {
@@ -170,8 +183,8 @@ struct slot_params {
170183
{"n_probs", sampling.n_probs},
171184
{"min_keep", sampling.min_keep},
172185
{"grammar", sampling.grammar},
173-
{"grammar_trigger_words", grammar_trigger_words},
174-
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
186+
{"grammar_lazy", sampling.grammar_lazy},
187+
{"grammar_triggers", grammar_triggers},
175188
{"preserved_tokens", sampling.preserved_tokens},
176189
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
177190
{"samplers", samplers},
@@ -356,24 +369,6 @@ struct server_task {
356369
}
357370

358371
{
359-
const auto grammar_triggers = data.find("grammar_triggers");
360-
if (grammar_triggers != data.end()) {
361-
for (const auto & t : *grammar_triggers) {
362-
common_grammar_trigger trigger;
363-
trigger.word = t.at("word");
364-
trigger.at_start = t.at("at_start");
365-
366-
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
367-
if (ids.size() == 1) {
368-
SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
369-
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
370-
params.sampling.preserved_tokens.insert(ids[0]);
371-
continue;
372-
}
373-
SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
374-
params.sampling.grammar_trigger_words.push_back(trigger);
375-
}
376-
}
377372
const auto preserved_tokens = data.find("preserved_tokens");
378373
if (preserved_tokens != data.end()) {
379374
for (const auto & t : *preserved_tokens) {
@@ -383,12 +378,48 @@ struct server_task {
383378
params.sampling.preserved_tokens.insert(ids[0]);
384379
} else {
385380
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
386-
SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
381+
SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
382+
}
383+
}
384+
}
385+
const auto grammar_triggers = data.find("grammar_triggers");
386+
if (grammar_triggers != data.end()) {
387+
for (const auto & t : *grammar_triggers) {
388+
auto type = static_cast<common_grammar_trigger_type>(t.at("type"));
389+
switch (type) {
390+
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
391+
{
392+
const std::string & word = t.at("value");
393+
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
394+
if (ids.size() == 1) {
395+
auto token = ids[0];
396+
if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), token) == params.sampling.preserved_tokens.end()) {
397+
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
398+
}
399+
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
400+
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, token});
401+
} else {
402+
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
403+
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
404+
}
405+
break;
406+
}
407+
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
408+
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
409+
{
410+
const std::string & pattern = t.at("value");
411+
params.sampling.grammar_triggers.push_back({type, pattern});
412+
break;
413+
}
414+
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
415+
throw std::runtime_error("Unexpected token trigger");
416+
default:
417+
throw std::runtime_error("Unknown trigger type");
387418
}
388419
}
389420
}
390421
if (params.sampling.grammar_lazy) {
391-
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
422+
GGML_ASSERT(params.sampling.grammar_triggers.size() > 0);
392423
}
393424
}
394425

@@ -2045,7 +2076,7 @@ struct server_context {
20452076

20462077
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
20472078
// Might be better to reject the request with a 400 ?
2048-
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict);
2079+
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict);
20492080
slot.params.n_predict = slot.n_predict;
20502081
}
20512082

0 commit comments

Comments
 (0)