@@ -764,7 +764,7 @@ static void llama_grammar_advance_stack_memo(
764764 if (it != stacks_cache.end ()) {
765765 advanced_stacks = it->second ;
766766 } else {
767- // Advance stacks with memorization
767+ // Advance stacks with memoization
768768 llama_grammar_advance_stack_memo_impl (rules, stack, advanced_stacks, stacks_cache);
769769 stacks_cache.insert (make_pair (stack, advanced_stacks));
770770 }
@@ -917,20 +917,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
917917 return grammar->stacks ;
918918}
919919
920- llama_grammar_stacks_cache & llama_grammar_get_stacks_cache (struct llama_grammar * grammar) {
921- return grammar->stacks_cache ;
922- }
923-
924- void llama_grammar_accept (
925- const llama_grammar_rules & rules,
926- const llama_grammar_stacks & stacks,
927- const uint32_t chr,
928- llama_grammar_stacks & stacks_new,
929- llama_grammar_stacks_cache & stacks_cache) {
930- stacks_new.clear ();
931- stacks_new.reserve (stacks.size ());
920+ void llama_grammar_accept (struct llama_grammar * grammar, uint32_t chr) {
921+ llama_grammar_stacks stacks_new;
922+ stacks_new.reserve (grammar->stacks .size ());
932923
933- for (const auto & stack : stacks) {
924+ for (const auto & stack : grammar-> stacks ) {
934925 if (stack.empty ()) {
935926 continue ;
936927 }
@@ -944,9 +935,11 @@ void llama_grammar_accept(
944935 if (!llama_grammar_is_end_of_sequence (pos)) {
945936 new_stack.push_back (pos);
946937 }
947- llama_grammar_advance_stack_memo (rules, new_stack, stacks_new, stacks_cache);
938+ llama_grammar_advance_stack_memo (grammar-> rules , new_stack, stacks_new, grammar-> stacks_cache );
948939 }
949940 }
941+
942+ grammar->stacks = std::move (stacks_new);
950943}
951944
952945llama_grammar_candidates llama_grammar_reject_candidates_for_stack (
@@ -1062,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl(
10621055 // Important: vec_rules has to be moved here, not copied, because stacks contains
10631056 // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
10641057 // then the pointers would be invalidated when the local vec_rules goes out of scope.
1065- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, std::move (stacks_cache), };
1058+ return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), std::move (stacks_cache), {} , };
10661059}
10671060
10681061struct llama_grammar * llama_grammar_init_impl (const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
@@ -1141,7 +1134,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
11411134 // Important: vec_rules has to be moved here, not copied, because stacks contains
11421135 // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
11431136 // then the pointers would be invalidated when the local vec_rules goes out of scope.
1144- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, std::move (stacks_cache), };
1137+ return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), std::move (stacks_cache), {} , };
11451138}
11461139
11471140void llama_grammar_free_impl (struct llama_grammar * grammar) {
@@ -1153,15 +1146,21 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
11531146}
11541147
11551148struct llama_grammar * llama_grammar_clone_impl (const struct llama_grammar & grammar) {
1156- llama_grammar * result = new llama_grammar { grammar.vocab , grammar.rules , grammar.stacks , grammar.partial_utf8 , };
1149+ llama_grammar * result = new llama_grammar {
1150+ grammar.vocab ,
1151+ grammar.rules ,
1152+ grammar.stacks ,
1153+ grammar.stacks_cache ,
1154+ grammar.partial_utf8 ,
1155+ };
11571156
11581157 // redirect elements in stacks to point to new rules
11591158 for (size_t is = 0 ; is < result->stacks .size (); is++) {
11601159 for (size_t ie = 0 ; ie < result->stacks [is].size (); ie++) {
11611160 for (size_t ir0 = 0 ; ir0 < grammar.rules .size (); ir0++) {
11621161 for (size_t ir1 = 0 ; ir1 < grammar.rules [ir0].size (); ir1++) {
11631162 if (grammar.stacks [is][ie] == &grammar.rules [ir0][ir1]) {
1164- result->stacks [is][ie] = &result->rules [ir0][ir1];
1163+ result->stacks [is][ie] = &result->rules [ir0][ir1];
11651164 }
11661165 }
11671166 }
@@ -1228,11 +1227,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
12281227 const auto decoded = decode_utf8 (piece, grammar.partial_utf8 );
12291228 const auto & code_points = decoded.first ;
12301229
1231- llama_grammar_stacks stacks_new;
1232-
12331230 for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
1234- llama_grammar_accept (grammar.rules , grammar.stacks , *it, stacks_new, grammar.stacks_cache );
1235- grammar.stacks = std::move (stacks_new);
1231+ llama_grammar_accept (&grammar, *it);
12361232 }
12371233
12381234 grammar.partial_utf8 = decoded.second ;
0 commit comments