Skip to content

Commit e6e9c13

Browse files
author
ochafik
committed
common_grammar_trigger: always use string value (+ optional token)
1 parent 2470a1c commit e6e9c13

File tree

5 files changed

+13
-14
lines changed

5 files changed

+13
-14
lines changed

common/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ enum common_grammar_trigger_type {
120120

121121
struct common_grammar_trigger {
122122
common_grammar_trigger_type type;
123-
std::variant<llama_token, std::string> value;
123+
std::string value;
124+
llama_token token = LLAMA_TOKEN_NULL;
124125
};
125126

126127
// sampling parameters

common/sampling.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,20 +166,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
166166
switch (trigger.type) {
167167
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
168168
{
169-
const auto & word = std::get<std::string>(trigger.value);
169+
const auto & word = trigger.value;
170170
patterns_anywhere.push_back(regex_escape(word));
171171
break;
172172
}
173173
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
174174
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
175175
{
176-
const auto & pattern = std::get<std::string>(trigger.value);
176+
const auto & pattern = trigger.value;
177177
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
178178
break;
179179
}
180180
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
181181
{
182-
const auto & token = std::get<llama_token>(trigger.value);
182+
const auto token = trigger.token;
183183
trigger_tokens.push_back(token);
184184
break;
185185
}

examples/server/server.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,16 @@ struct slot_params {
135135
for (const auto & trigger : sampling.grammar_triggers) {
136136
switch (trigger.type) {
137137
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
138-
grammar_triggers.push_back({{"word", std::get<std::string>(trigger.value)}});
138+
grammar_triggers.push_back({{"word", trigger.value}});
139139
break;
140140
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
141-
grammar_triggers.push_back({{"pattern", std::get<std::string>(trigger.value)}});
141+
grammar_triggers.push_back({{"pattern", trigger.value}});
142142
break;
143143
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
144-
grammar_triggers.push_back({{"pattern_start", std::get<std::string>(trigger.value)}});
144+
grammar_triggers.push_back({{"pattern_start", trigger.value}});
145145
break;
146146
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
147-
grammar_triggers.push_back({{"token", std::get<llama_token>(trigger.value)}});
147+
grammar_triggers.push_back({{"token", trigger.token}});
148148
break;
149149
}
150150
}

examples/server/utils.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -623,9 +623,7 @@ static json oaicompat_completion_params_parse(
623623
for (const auto & trigger : chat_params.grammar_triggers) {
624624
grammar_triggers.push_back({
625625
{"type", (int) trigger.type},
626-
{"value", trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN
627-
? json((int) std::get<llama_token>(trigger.value))
628-
: json(std::get<std::string>(trigger.value))},
626+
{"value", trigger.token},
629627
});
630628
}
631629
llama_params["grammar_triggers"] = grammar_triggers;

tests/test-chat.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,21 +242,21 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
242242
switch (trigger.type) {
243243
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
244244
{
245-
const auto & word = std::get<std::string>(trigger.value);
245+
const auto & word = trigger.value;
246246
pos = constrained.find(word);
247247
break;
248248
}
249249
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
250250
{
251-
const auto & pattern = std::get<std::string>(trigger.value);
251+
const auto & pattern = trigger.value;
252252
if (std::regex_search(constrained, match, std::regex(pattern))) {
253253
pos = match.position();
254254
}
255255
break;
256256
}
257257
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
258258
{
259-
const auto & pattern = std::get<std::string>(trigger.value);
259+
const auto & pattern = trigger.value;
260260
if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) {
261261
pos = 0;
262262
}

0 commit comments

Comments
 (0)