 #include <algorithm>
 #include <stdexcept>
 
+#include <string_view>
+#include <unordered_map>
+
 //
 // helpers
 //
 
+// Raw byte range; a struct (rather than a std::pair alias) so the std::hash specialization below is well-formed
+struct bytes { const char * data; size_t size; };
+
+template <>
+struct std::hash<bytes>
+{
+    std::size_t operator()(const bytes & x) const noexcept
+    {
+        // reuse the standard string_view hash over the raw bytes
+        return std::hash<std::string_view>{}({x.data, x.size});
+    }
+};
+
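+// Two-level memo: hash of a grammar stack -> hash of a candidate list -> rejected candidates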
+using candidates_memos = std::unordered_map<size_t, llama_grammar_candidates>;
+using stack_memos = std::unordered_map<size_t, candidates_memos>;
+static stack_memos memo_cache; // NOTE: process-wide and not thread-safe
+
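+// The hashes cover raw pointer values, so cached entries are only meaningful while the grammar that produced them is alive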
+static void llama_grammar_reset_memos() {
+    memo_cache.clear();
+}
+
 // NOTE: assumes valid utf8 (but checks for overrun)
 static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -864,6 +888,38 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( |
         }
         return rejects;
     }
+
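+    // Hash the stack's raw bytes (an array of pointers into the grammar's rules)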
+    auto stack_hash_start = reinterpret_cast<const char *>(stack.data());
+    auto stack_hash_size = sizeof(stack[0]) * stack.size();
+    auto stack_hash = std::hash<bytes>{}({ stack_hash_start, stack_hash_size });
+
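+    // Cache slot the computed rejects will be written into, once memoization is known to apply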
+    llama_grammar_candidates * cache_target = nullptr;
+
+    // Tests show that >75% of candidate lists are under 1280 bytes and 50% are under 640 bytes.
+    // Most 'problem' loops are under 24 bytes. However, candidate lists can be over 72 KB,
+    // so we need to limit our checks.
+
+    // We'll only attempt to memoize candidate lists under 80 bytes.
+    // Applying an aggressive size cutoff before any other processing spares the easy cases
+    // the extra hashing work while still rescuing the hard cases from slowdowns or hangs.
+    // This speeds up both the easy and the hard cases.
+    const size_t hash_cutoff = 80;
+    auto candidates_hash_size = sizeof(candidates[0]) * candidates.size();
+    if (candidates_hash_size < hash_cutoff) {
+        // Only check the stack hash first - stacks are usually ~24 bytes, and almost always under 64 bytes
+        if (auto cache_hit = memo_cache.find(stack_hash); cache_hit != memo_cache.end()) {
+            auto & candidates_memos = cache_hit->second;
+            auto candidates_hash_start = reinterpret_cast<const char *>(candidates.data());
+            auto candidates_hash = std::hash<bytes>{}({ candidates_hash_start, candidates_hash_size });
+            if (auto cache_hit2 = candidates_memos.find(candidates_hash); cache_hit2 != candidates_memos.end()) {
+                return cache_hit2->second;
+            } else {
+                cache_target = &(candidates_memos[candidates_hash]);
+            }
+        } else {
+            memo_cache[stack_hash]; // first sighting of this stack: create its entry, but don't hash the candidates yet
+        }
+    }
 
     const llama_grammar_element * stack_pos = stack.back();
 
@@ -900,6 +956,9 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( |
         rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
     }
 
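+    // Memoize the computed rejects for this (stack, candidates) pair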
+    if (cache_target) {
+        *cache_target = rejects;
+    }
     return rejects;
 }
 