@@ -687,31 +687,17 @@ static bool llama_grammar_match_partial_char(
687687// additionally memorizes the stack to its possible stacks by mapping
688688// < llama_grammar_stack, llama_grammar_stacks >
689689
690- struct VectorPointerHash {
691- size_t operator ()(const llama_grammar_stack & v) const {
692- size_t seed = v.size ();
693- for (const auto * ptr : v) {
694- seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6 ) + (seed >> 2 );
695- }
696- return seed;
697- }
698- };
699-
700- static std::unordered_map<
701- llama_grammar_stack,
702- llama_grammar_stacks,
703- VectorPointerHash>
704- llama_grammar_stacks_cache = {};
705-
706690static void llama_grammar_advance_stack_memo (
707691 const llama_grammar_rules & rules,
708692 const llama_grammar_stack & stack,
709- llama_grammar_stacks & new_stacks);
693+ llama_grammar_stacks & new_stacks,
694+ llama_grammar_stacks_cache & stacks_cache);
710695
711696static void llama_grammar_advance_stack_memo_impl (
712697 const llama_grammar_rules & rules,
713698 const llama_grammar_stack & stack,
714- llama_grammar_stacks & new_stacks) {
699+ llama_grammar_stacks & new_stacks,
700+ llama_grammar_stacks_cache & stacks_cache) {
715701 if (stack.empty ()) {
716702 if (std::find (new_stacks.begin (), new_stacks.end (), stack) == new_stacks.end ()) {
717703 new_stacks.emplace_back (stack);
@@ -736,7 +722,7 @@ static void llama_grammar_advance_stack_memo_impl(
736722 // if alternate is nonempty, add to stack
737723 new_stack.push_back (subpos);
738724 }
739- llama_grammar_advance_stack_memo (rules, new_stack, new_stacks);
725+ llama_grammar_advance_stack_memo (rules, new_stack, new_stacks, stacks_cache );
740726 while (!llama_grammar_is_end_of_sequence (subpos)) {
741727 // scan to end of alternate def
742728 subpos++;
@@ -769,17 +755,18 @@ static void llama_grammar_advance_stack_memo_impl(
769755static void llama_grammar_advance_stack_memo (
770756 const llama_grammar_rules & rules,
771757 const llama_grammar_stack & stack,
772- llama_grammar_stacks & new_stacks) {
758+ llama_grammar_stacks & new_stacks,
759+ llama_grammar_stacks_cache & stacks_cache) {
773760
774761 llama_grammar_stacks advanced_stacks;
775762 // Look if stack is already in memory
776- auto it = llama_grammar_stacks_cache .find (stack);
777- if (it != llama_grammar_stacks_cache .end ()) {
763+ auto it = stacks_cache .find (stack);
764+ if (it != stacks_cache .end ()) {
778765 advanced_stacks = it->second ;
779766 } else {
780767 // Advance stacks with memorization
781- llama_grammar_advance_stack_memo_impl (rules, stack, advanced_stacks);
782- llama_grammar_stacks_cache .insert (make_pair (stack, advanced_stacks));
768+ llama_grammar_advance_stack_memo_impl (rules, stack, advanced_stacks, stacks_cache );
769+ stacks_cache .insert (make_pair (stack, advanced_stacks));
783770 }
784771 // Add the advanced stacks to new_stacks avoiding duplicates
785772 for (const auto & new_stack : advanced_stacks) {
@@ -934,7 +921,8 @@ void llama_grammar_accept(
934921 const llama_grammar_rules & rules,
935922 const llama_grammar_stacks & stacks,
936923 const uint32_t chr,
937- llama_grammar_stacks & stacks_new) {
924+ llama_grammar_stacks & stacks_new,
925+ llama_grammar_stacks_cache & stacks_cache) {
938926 stacks_new.clear ();
939927 stacks_new.reserve (stacks.size ());
940928
@@ -952,7 +940,7 @@ void llama_grammar_accept(
952940 if (!llama_grammar_is_end_of_sequence (pos)) {
953941 new_stack.push_back (pos);
954942 }
955- llama_grammar_advance_stack_memo (rules, new_stack, stacks_new);
943+ llama_grammar_advance_stack_memo (rules, new_stack, stacks_new, stacks_cache );
956944 }
957945 }
958946}
@@ -1019,8 +1007,6 @@ struct llama_grammar * llama_grammar_init_impl(
10191007 const llama_grammar_element ** rules,
10201008 size_t n_rules,
10211009 size_t start_rule_index) {
1022- // Clear stacks cache
1023- llama_grammar_stacks_cache.clear ();
10241010 const llama_grammar_element * pos;
10251011
10261012 // copy rule definitions into vectors
@@ -1048,14 +1034,15 @@ struct llama_grammar * llama_grammar_init_impl(
10481034
10491035 // loop over alternates of start rule to build initial stacks
10501036 llama_grammar_stacks stacks;
1037+ llama_grammar_stacks_cache stacks_cache;
10511038 pos = vec_rules[start_rule_index].data ();
10521039 do {
10531040 llama_grammar_stack stack;
10541041 if (!llama_grammar_is_end_of_sequence (pos)) {
10551042 // if alternate is nonempty, add to stack
10561043 stack.push_back (pos);
10571044 }
1058- llama_grammar_advance_stack_memo (vec_rules, stack, stacks);
1045+ llama_grammar_advance_stack_memo (vec_rules, stack, stacks, stacks_cache );
10591046 while (!llama_grammar_is_end_of_sequence (pos)) {
10601047 // scan to end of alternate def
10611048 pos++;
@@ -1075,8 +1062,6 @@ struct llama_grammar * llama_grammar_init_impl(
10751062}
10761063
10771064struct llama_grammar * llama_grammar_init_impl (const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
1078- // Clear stacks cache
1079- llama_grammar_stacks_cache.clear ();
10801065 llama_grammar_parser parser;
10811066
10821067 // if there is a grammar, parse it
@@ -1128,14 +1113,15 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
11281113
11291114 // loop over alternates of start rule to build initial stacks
11301115 llama_grammar_stacks stacks;
1116+ llama_grammar_stacks_cache stacks_cache;
11311117 pos = vec_rules[start_rule_index].data ();
11321118 do {
11331119 llama_grammar_stack stack;
11341120 if (!llama_grammar_is_end_of_sequence (pos)) {
11351121 // if alternate is nonempty, add to stack
11361122 stack.push_back (pos);
11371123 }
1138- llama_grammar_advance_stack_memo (vec_rules, stack, stacks);
1124+ llama_grammar_advance_stack_memo (vec_rules, stack, stacks, stacks_cache );
11391125 while (!llama_grammar_is_end_of_sequence (pos)) {
11401126 // scan to end of alternate def
11411127 pos++;
@@ -1239,9 +1225,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
12391225 const auto & code_points = decoded.first ;
12401226
12411227 llama_grammar_stacks stacks_new;
1228+ llama_grammar_stacks_cache stacks_cache;
12421229
12431230 for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
1244- llama_grammar_accept (grammar.rules , grammar.stacks , *it, stacks_new);
1231+ llama_grammar_accept (grammar.rules , grammar.stacks , *it, stacks_new, stacks_cache );
12451232 grammar.stacks = std::move (stacks_new);
12461233 }
12471234
0 commit comments