Skip to content

Commit 54dde5e

Browse files
authored
Add memoized cache to llama_grammar_reject_candidates_for_stack (LostRuins#1615)
* Add memoized cache to llama_grammar_reject_candidates_for_stack
* Make the size cutoff more aggressive and move it to the outer branch
* Update comment
* Add a cache reset whenever the grammar is reloaded
* Remove explicit reference types for compiler portability
1 parent b884a7f commit 54dde5e

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

gpttype_adapter.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,6 +1773,7 @@ static void load_grammar(const std::string & gammarstr)
17731773
{
17741774
if(grammar!=nullptr) //on demand free when next grammar is loaded
17751775
{
1776+
llama_grammar_reset_memos();
17761777
llama_grammar_free_impl(grammar);
17771778
grammar = nullptr;
17781779
}

src/llama-grammar.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,34 @@
88
#include <algorithm>
99
#include <stdexcept>
1010

11+
#include <iostream>
12+
#include <string_view>
13+
#include <unordered_set>
14+
1115
//
1216
// helpers
1317
//
1418

19+
using bytes = std::pair<const char*, size_t>;
20+
using hash_entry_size = std::pair<size_t, size_t>;
21+
22+
template <>
23+
struct std::hash<bytes>
24+
{
25+
std::size_t operator()(const bytes& x) const noexcept
26+
{
27+
return std::hash<std::string_view>{}({x.first, x.second});
28+
}
29+
};
30+
31+
using candidates_memos = std::unordered_map<size_t, llama_grammar_candidates>;
32+
using stack_memos = std::unordered_map<size_t, candidates_memos>;
33+
static stack_memos memo_cache;
34+
35+
static void llama_grammar_reset_memos() {
36+
memo_cache.clear();
37+
}
38+
1539
// NOTE: assumes valid utf8 (but checks for overrun)
1640
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
1741
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -864,6 +888,38 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
864888
}
865889
return rejects;
866890
}
891+
892+
auto stack_hash_start = reinterpret_cast<const char *>(stack.data());
893+
auto stack_hash_size = sizeof(stack[0]) * stack.size();
894+
auto stack_hash = std::hash<bytes>{}({ stack_hash_start, stack_hash_size });
895+
896+
llama_grammar_candidates * cache_target = nullptr;
897+
898+
// Tests show that >75% of candidate lists are under 1280b and 50% are under 640b.
899+
// Most 'problem' loops are under 24b. However, candidate lists can be over 72k,
900+
// so we need to limit our checks.
901+
902+
// We'll only attempt to memoize candidate lists under 80b
903+
// Doing an over-aggressive size cutoff first before any other processing 'saves' easy cases
904+
// extra processing but still rescues 'hard' cases from slow down or hangs.
905+
// This leads to a speed up of both easy and hard cases.
906+
const size_t hash_cutoff = 80;
907+
auto candidates_hash_size = sizeof(candidates[0]) * candidates.size();
908+
if (candidates_hash_size < hash_cutoff) {
909+
// Only check stack hash first - these are usually ~24b, and almost always under 64b
910+
if (auto cache_hit = memo_cache.find(stack_hash); cache_hit != memo_cache.end()) {
911+
auto & candidates_memos = cache_hit->second;
912+
auto candidates_hash_start = reinterpret_cast<const char *>(candidates.data());
913+
auto candidates_hash = std::hash<bytes>{}({ candidates_hash_start, candidates_hash_size });
914+
if (auto cache_hit2 = candidates_memos.find(candidates_hash); cache_hit2 != candidates_memos.end()) {
915+
return cache_hit2->second;
916+
} else {
917+
cache_target = &(candidates_memos[candidates_hash]);
918+
}
919+
} else {
920+
memo_cache[stack_hash];
921+
}
922+
}
867923

868924
const llama_grammar_element * stack_pos = stack.back();
869925

@@ -900,6 +956,9 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
900956
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
901957
}
902958

959+
if (cache_target) {
960+
*cache_target = rejects;
961+
}
903962
return rejects;
904963
}
905964

0 commit comments

Comments
 (0)