Skip to content

Commit 1317a35

Browse files
author
ochafik
committed
add common_grammar_trigger.{to_json,from_json}
1 parent 5d43b72 commit 1317a35

File tree

4 files changed

+50
-53
lines changed

4 files changed

+50
-53
lines changed

common/common.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
// Change JSON_ASSERT from assert() to GGML_ASSERT:
1111
#define JSON_ASSERT GGML_ASSERT
1212
#include "json.hpp"
13-
#include "json-schema-to-grammar.h"
1413
#include "llama.h"
1514

1615
#include <algorithm>
@@ -2031,3 +2030,25 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
20312030
return result;
20322031
}
20332032

2033+
template <>
2034+
json common_grammar_trigger::to_json() const {
2035+
json out {
2036+
{"type", (int) type},
2037+
{"value", value},
2038+
};
2039+
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
2040+
out["token"] = (int) token;
2041+
}
2042+
return out;
2043+
}
2044+
2045+
template <>
2046+
common_grammar_trigger common_grammar_trigger::from_json(const json & in) {
2047+
common_grammar_trigger out;
2048+
out.type = (common_grammar_trigger_type) in.at("type").get<int>();
2049+
out.value = in.at("value").get<std::string>();
2050+
if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
2051+
out.token = (llama_token) in.at("token").get<int>();
2052+
}
2053+
return out;
2054+
}

common/common.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include <string>
99
#include <vector>
1010
#include <sstream>
11-
#include <variant>
1211

1312
#ifdef _WIN32
1413
#define DIRECTORY_SEPARATOR '\\'
@@ -122,6 +121,12 @@ struct common_grammar_trigger {
122121
common_grammar_trigger_type type;
123122
std::string value;
124123
llama_token token = LLAMA_TOKEN_NULL;
124+
125+
template <class T>
126+
T to_json() const;
127+
128+
template <class T>
129+
static common_grammar_trigger from_json(const T & in);
125130
};
126131

127132
// sampling parameters
@@ -172,7 +177,7 @@ struct common_params_sampling {
172177

173178
std::string grammar; // optional BNF-like grammar to constrain sampling
174179
bool grammar_lazy = false;
175-
std::vector<common_grammar_trigger> grammar_triggers; // optional trigger words to trigger lazy grammar
180+
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
176181
std::set<llama_token> preserved_tokens;
177182

178183
std::vector<llama_logit_bias> logit_bias; // logit biases to apply

examples/server/server.cpp

Lines changed: 20 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -133,20 +133,7 @@ struct slot_params {
133133

134134
auto grammar_triggers = json::array();
135135
for (const auto & trigger : sampling.grammar_triggers) {
136-
switch (trigger.type) {
137-
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
138-
grammar_triggers.push_back({{"word", trigger.value}});
139-
break;
140-
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
141-
grammar_triggers.push_back({{"pattern", trigger.value}});
142-
break;
143-
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
144-
grammar_triggers.push_back({{"pattern_start", trigger.value}});
145-
break;
146-
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
147-
grammar_triggers.push_back({{"token", trigger.token}});
148-
break;
149-
}
136+
grammar_triggers.push_back(trigger.to_json<json>());
150137
}
151138

152139
return json {
@@ -385,44 +372,31 @@ struct server_task {
385372
const auto grammar_triggers = data.find("grammar_triggers");
386373
if (grammar_triggers != data.end()) {
387374
for (const auto & t : *grammar_triggers) {
388-
auto type = static_cast<common_grammar_trigger_type>(t.at("type"));
389-
switch (type) {
390-
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
391-
{
392-
const std::string & word = t.at("value");
393-
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
394-
if (ids.size() == 1) {
395-
auto token = ids[0];
396-
if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
397-
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
398-
}
399-
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
400-
common_grammar_trigger trigger;
401-
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
402-
trigger.value = token;
403-
params.sampling.grammar_triggers.push_back(trigger);
404-
} else {
405-
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
406-
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
375+
auto ct = common_grammar_trigger::from_json(t);
376+
if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
377+
const auto & word = ct.value;
378+
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
379+
if (ids.size() == 1) {
380+
auto token = ids[0];
381+
if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
382+
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
407383
}
408-
break;
409-
}
410-
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
411-
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
412-
{
413-
const std::string & pattern = t.at("value");
414-
params.sampling.grammar_triggers.push_back({type, pattern});
415-
break;
384+
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
385+
common_grammar_trigger trigger;
386+
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
387+
trigger.value = (llama_token) token;
388+
params.sampling.grammar_triggers.push_back(trigger);
389+
} else {
390+
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
391+
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
416392
}
417-
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
418-
throw std::runtime_error("Unespected token trigger");
419-
default:
420-
throw std::runtime_error("Unknown trigger type");
393+
} else {
394+
params.sampling.grammar_triggers.push_back(ct);
421395
}
422396
}
423397
}
424398
if (params.sampling.grammar_lazy) {
425-
GGML_ASSERT(params.sampling.grammar_triggers.size() > 0);
399+
GGML_ASSERT(!params.sampling.grammar_triggers.empty());
426400
}
427401
}
428402

examples/server/utils.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -621,10 +621,7 @@ static json oaicompat_completion_params_parse(
621621
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
622622
auto grammar_triggers = json::array();
623623
for (const auto & trigger : chat_params.grammar_triggers) {
624-
grammar_triggers.push_back({
625-
{"type", (int) trigger.type},
626-
{"value", trigger.token},
627-
});
624+
grammar_triggers.push_back(trigger.to_json<json>());
628625
}
629626
llama_params["grammar_triggers"] = grammar_triggers;
630627
llama_params["preserved_tokens"] = chat_params.preserved_tokens;

0 commit comments

Comments
 (0)