@@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) {
560560 }
561561 }
562562 } catch (const std::exception & err) {
563- fprintf (stderr, " %s: error parsing grammar: %s\n " , __func__, err.what ());
563+ fprintf (stderr, " %s: error parsing grammar: %s\n\n %s \n " , __func__, err.what (), src );
564564 rules.clear ();
565565 return false ;
566566 }
@@ -960,10 +960,28 @@ struct llama_grammar * llama_grammar_init_impl(
960960 // Important: vec_rules has to be moved here, not copied, because stacks contains
961961 // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
962962 // then the pointers would be invalidated when the local vec_rules goes out of scope.
963- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, };
963+ return new llama_grammar {
964+ vocab,
965+ std::move (vec_rules),
966+ std::move (stacks),
967+ /* .partial_utf8 = */ {},
968+ /* .lazy =*/ false ,
969+ /* .awaiting_trigger = */ false ,
970+ /* .trigger_buffer = */ " " ,
971+ /* .trigger_tokens = */ {},
972+ /* .trigger_words = */ {},
973+ };
964974}
965975
966- struct llama_grammar * llama_grammar_init_impl (const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
976+ struct llama_grammar * llama_grammar_init_impl (
977+ const struct llama_vocab * vocab,
978+ const char * grammar_str,
979+ const char * grammar_root,
980+ bool lazy,
981+ const char ** trigger_words,
982+ size_t num_trigger_words,
983+ const llama_token * trigger_tokens,
984+ size_t num_trigger_tokens) {
967985 llama_grammar_parser parser;
968986
969987 // if there is a grammar, parse it
@@ -1035,10 +1053,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
10351053 }
10361054 } while (true );
10371055
1056+ std::vector<llama_token> vec_trigger_tokens;
1057+ std::vector<std::string> vec_trigger_words;
1058+ for (size_t i = 0 ; i < num_trigger_tokens; i++) {
1059+ GGML_ASSERT (trigger_tokens != nullptr );
1060+ vec_trigger_tokens.push_back (trigger_tokens[i]);
1061+ }
1062+ for (size_t i = 0 ; i < num_trigger_words; i++) {
1063+ GGML_ASSERT (trigger_words != nullptr );
1064+ vec_trigger_words.push_back (trigger_words[i]);
1065+ }
1066+
10381067 // Important: vec_rules has to be moved here, not copied, because stacks contains
10391068 // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
10401069 // then the pointers would be invalidated when the local vec_rules goes out of scope.
1041- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, };
1070+ return new llama_grammar {
1071+ vocab,
1072+ std::move (vec_rules),
1073+ std::move (stacks),
1074+ /* .partial_utf8 = */ {},
1075+ /* .lazy = */ lazy,
1076+ /* .awaiting_trigger = */ lazy,
1077+ /* .trigger_buffer = */ " " ,
1078+ std::move (vec_trigger_tokens),
1079+ std::move (vec_trigger_words),
1080+ };
10421081}
10431082
10441083void llama_grammar_free_impl (struct llama_grammar * grammar) {
@@ -1055,6 +1094,11 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
10551094 grammar.rules ,
10561095 grammar.stacks ,
10571096 grammar.partial_utf8 ,
1097+ grammar.lazy ,
1098+ grammar.awaiting_trigger ,
1099+ grammar.trigger_buffer ,
1100+ grammar.trigger_tokens ,
1101+ grammar.trigger_words ,
10581102 };
10591103
10601104 // redirect elements in stacks to point to new rules
@@ -1076,6 +1120,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
10761120void llama_grammar_apply_impl (const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
10771121 GGML_ASSERT (grammar.vocab != nullptr );
10781122
1123+ if (grammar.awaiting_trigger ) {
1124+ return ;
1125+ }
1126+
10791127 bool allow_eog = false ;
10801128 for (const auto & stack : grammar.stacks ) {
10811129 if (stack.empty ()) {
@@ -1115,6 +1163,34 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
11151163void llama_grammar_accept_impl (struct llama_grammar & grammar, llama_token token) {
11161164 GGML_ASSERT (grammar.vocab != nullptr );
11171165
1166+ const auto & piece = grammar.vocab ->token_to_piece (token);
1167+
1168+ if (grammar.awaiting_trigger ) {
1169+ if (std::find (grammar.trigger_tokens .begin (), grammar.trigger_tokens .end (), token) != grammar.trigger_tokens .end ()) {
1170+ grammar.awaiting_trigger = false ;
1171+ grammar.trigger_buffer .clear ();
1172+ llama_grammar_accept_str (grammar, piece);
1173+ LLAMA_LOG_DEBUG (" Grammar triggered on token %u (`%s`)" , token, piece.c_str ());
1174+ return ;
1175+ } else {
1176+ // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
1177+ grammar.trigger_buffer += piece;
1178+ for (const auto & word : grammar.trigger_words ) {
1179+ auto pos = grammar.trigger_buffer .find (word);
1180+ if (pos != std::string::npos) {
1181+ grammar.awaiting_trigger = false ;
1182+ auto constrained_str = grammar.trigger_buffer .substr (pos);
1183+ grammar.trigger_buffer .clear ();
1184+ llama_grammar_accept_str (grammar, constrained_str);
1185+ LLAMA_LOG_DEBUG (" Grammar triggered on word `%s`" , word.c_str ());
1186+ return ;
1187+ }
1188+ }
1189+ LLAMA_LOG_DEBUG (" Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n " , token, piece.c_str (), grammar.trigger_buffer .c_str ());
1190+ return ;
1191+ }
1192+ }
1193+
11181194 if (grammar.vocab ->is_eog (token)) {
11191195 for (const auto & stack : grammar.stacks ) {
11201196 if (stack.empty ()) {
@@ -1124,8 +1200,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
11241200 GGML_ABORT (" fatal error" );
11251201 }
11261202
1127- const std::string & piece = grammar.vocab ->token_to_piece (token);
1203+ llama_grammar_accept_str (grammar, piece);
1204+ }
11281205
1206+ void llama_grammar_accept_str (struct llama_grammar & grammar, const std::string & piece) {
11291207 // Note terminating 0 in decoded string
11301208 const auto decoded = decode_utf8 (piece, grammar.partial_utf8 );
11311209 const auto & code_points = decoded.first ;
@@ -1135,5 +1213,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
11351213 }
11361214
11371215 grammar.partial_utf8 = decoded.second ;
1138- GGML_ASSERT (!grammar.stacks .empty ());
1216+ if (grammar.stacks .empty ()) {
1217+ throw std::runtime_error (" Unexpected empty grammar stack after accepting piece: " + piece);
1218+ }
11391219}
0 commit comments