@@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) {
560
560
}
561
561
}
562
562
} catch (const std::exception & err) {
563
- fprintf (stderr, " %s: error parsing grammar: %s\n " , __func__, err.what ());
563
+ fprintf (stderr, " %s: error parsing grammar: %s\n\n %s \n " , __func__, err.what (), src );
564
564
rules.clear ();
565
565
return false ;
566
566
}
@@ -960,10 +960,28 @@ struct llama_grammar * llama_grammar_init_impl(
960
960
// Important: vec_rules has to be moved here, not copied, because stacks contains
961
961
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
962
962
// then the pointers would be invalidated when the local vec_rules goes out of scope.
963
- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, };
963
+ return new llama_grammar {
964
+ vocab,
965
+ std::move (vec_rules),
966
+ std::move (stacks),
967
+ /* .partial_utf8 = */ {},
968
+ /* .lazy =*/ false ,
969
+ /* .awaiting_trigger = */ false ,
970
+ /* .trigger_buffer = */ " " ,
971
+ /* .trigger_tokens = */ {},
972
+ /* .trigger_words = */ {},
973
+ };
964
974
}
965
975
966
- struct llama_grammar * llama_grammar_init_impl (const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
976
+ struct llama_grammar * llama_grammar_init_impl (
977
+ const struct llama_vocab * vocab,
978
+ const char * grammar_str,
979
+ const char * grammar_root,
980
+ bool lazy,
981
+ const char ** trigger_words,
982
+ size_t num_trigger_words,
983
+ const llama_token * trigger_tokens,
984
+ size_t num_trigger_tokens) {
967
985
llama_grammar_parser parser;
968
986
969
987
// if there is a grammar, parse it
@@ -1035,10 +1053,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
1035
1053
}
1036
1054
} while (true );
1037
1055
1056
+ std::vector<llama_token> vec_trigger_tokens;
1057
+ std::vector<std::string> vec_trigger_words;
1058
+ for (size_t i = 0 ; i < num_trigger_tokens; i++) {
1059
+ GGML_ASSERT (trigger_tokens != nullptr );
1060
+ vec_trigger_tokens.push_back (trigger_tokens[i]);
1061
+ }
1062
+ for (size_t i = 0 ; i < num_trigger_words; i++) {
1063
+ GGML_ASSERT (trigger_words != nullptr );
1064
+ vec_trigger_words.push_back (trigger_words[i]);
1065
+ }
1066
+
1038
1067
// Important: vec_rules has to be moved here, not copied, because stacks contains
1039
1068
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1040
1069
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1041
- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, };
1070
+ return new llama_grammar {
1071
+ vocab,
1072
+ std::move (vec_rules),
1073
+ std::move (stacks),
1074
+ /* .partial_utf8 = */ {},
1075
+ /* .lazy = */ lazy,
1076
+ /* .awaiting_trigger = */ lazy,
1077
+ /* .trigger_buffer = */ " " ,
1078
+ std::move (vec_trigger_tokens),
1079
+ std::move (vec_trigger_words),
1080
+ };
1042
1081
}
1043
1082
1044
1083
void llama_grammar_free_impl (struct llama_grammar * grammar) {
@@ -1055,6 +1094,11 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
1055
1094
grammar.rules ,
1056
1095
grammar.stacks ,
1057
1096
grammar.partial_utf8 ,
1097
+ grammar.lazy ,
1098
+ grammar.awaiting_trigger ,
1099
+ grammar.trigger_buffer ,
1100
+ grammar.trigger_tokens ,
1101
+ grammar.trigger_words ,
1058
1102
};
1059
1103
1060
1104
// redirect elements in stacks to point to new rules
@@ -1076,6 +1120,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
1076
1120
void llama_grammar_apply_impl (const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
1077
1121
GGML_ASSERT (grammar.vocab != nullptr );
1078
1122
1123
+ if (grammar.awaiting_trigger ) {
1124
+ return ;
1125
+ }
1126
+
1079
1127
bool allow_eog = false ;
1080
1128
for (const auto & stack : grammar.stacks ) {
1081
1129
if (stack.empty ()) {
@@ -1115,6 +1163,34 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
1115
1163
void llama_grammar_accept_impl (struct llama_grammar & grammar, llama_token token) {
1116
1164
GGML_ASSERT (grammar.vocab != nullptr );
1117
1165
1166
+ const auto & piece = grammar.vocab ->token_to_piece (token);
1167
+
1168
+ if (grammar.awaiting_trigger ) {
1169
+ if (std::find (grammar.trigger_tokens .begin (), grammar.trigger_tokens .end (), token) != grammar.trigger_tokens .end ()) {
1170
+ grammar.awaiting_trigger = false ;
1171
+ grammar.trigger_buffer .clear ();
1172
+ llama_grammar_accept_str (grammar, piece);
1173
+ LLAMA_LOG_DEBUG (" Grammar triggered on token %u (`%s`)" , token, piece.c_str ());
1174
+ return ;
1175
+ } else {
1176
+ // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
1177
+ grammar.trigger_buffer += piece;
1178
+ for (const auto & word : grammar.trigger_words ) {
1179
+ auto pos = grammar.trigger_buffer .find (word);
1180
+ if (pos != std::string::npos) {
1181
+ grammar.awaiting_trigger = false ;
1182
+ auto constrained_str = grammar.trigger_buffer .substr (pos);
1183
+ grammar.trigger_buffer .clear ();
1184
+ llama_grammar_accept_str (grammar, constrained_str);
1185
+ LLAMA_LOG_DEBUG (" Grammar triggered on word `%s`" , word.c_str ());
1186
+ return ;
1187
+ }
1188
+ }
1189
+ LLAMA_LOG_DEBUG (" Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n " , token, piece.c_str (), grammar.trigger_buffer .c_str ());
1190
+ return ;
1191
+ }
1192
+ }
1193
+
1118
1194
if (grammar.vocab ->is_eog (token)) {
1119
1195
for (const auto & stack : grammar.stacks ) {
1120
1196
if (stack.empty ()) {
@@ -1124,8 +1200,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
1124
1200
GGML_ABORT (" fatal error" );
1125
1201
}
1126
1202
1127
- const std::string & piece = grammar.vocab ->token_to_piece (token);
1203
+ llama_grammar_accept_str (grammar, piece);
1204
+ }
1128
1205
1206
+ void llama_grammar_accept_str (struct llama_grammar & grammar, const std::string & piece) {
1129
1207
// Note terminating 0 in decoded string
1130
1208
const auto decoded = decode_utf8 (piece, grammar.partial_utf8 );
1131
1209
const auto & code_points = decoded.first ;
@@ -1135,5 +1213,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
1135
1213
}
1136
1214
1137
1215
grammar.partial_utf8 = decoded.second ;
1138
- GGML_ASSERT (!grammar.stacks .empty ());
1216
+ if (grammar.stacks .empty ()) {
1217
+ throw std::runtime_error (" Unexpected empty grammar stack after accepting piece: " + piece);
1218
+ }
1139
1219
}
0 commit comments