 #include "common.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "llama-vocab.h"
 #include "log.h"
 #include "sampling.h"
 #include "speculative.h"
@@ -3841,7 +3842,65 @@ int main(int argc, char ** argv) {
     // TODO: this log can become very long, put it behind a flag or think about a more compact format
     //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());

-    std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+    std::vector<llama_tokens> tokenized_prompts; // start of new tokenization code based on caches; it may need optimizations and bug fixes
+    if (prompt.is_string()) { // attempt tokenization based on the slot token caches first, only for prompts consisting of a single string
+        llama_tokens cache_based_tokenization;
+        std::string prompt_string = prompt.get<std::string>();
+        size_t max_prompt_match_in_chars = 0;
+
+        SRV_DBG("Attempting slot cache based tokenization of the prompt, total prompt length %zu characters.\n", prompt_string.size());
+        for (size_t slot_index = 0; slot_index < ctx_server.slots.size(); slot_index++) {
+            size_t prompt_index = 0;
+            size_t cache_index = 0;
+            llama_tokens partially_tokenized_prompt;
+            llama_tokens cache_tokens = ctx_server.slots[slot_index].cache_tokens; // accessing the caches like this might be unsafe
+
+            if (cache_tokens.size() > 0) {
+                SRV_DBG("Slot %zu has %zu cached tokens, attempting prompt tokenization based on them.\n", slot_index, cache_tokens.size());
+                for (cache_index = 0; cache_index < cache_tokens.size() && prompt_index < prompt_string.size(); cache_index++) {
+                    llama_token token = cache_tokens[cache_index];
+                    const std::string token_string = common_token_to_piece(ctx_server.vocab, token, true);
+                    size_t token_size = token_string.size();
+
+                    if (prompt_index + token_size <= prompt_string.size() && prompt_string.compare(prompt_index, token_size, token_string) == 0) {
+                        prompt_index += token_size;
+                        partially_tokenized_prompt.push_back(token);
+                    } else if (cache_index == 0) { // the first token from the cache doesn't have to be in the prompt, as it might be a BOS token, so just add it. This might cause issues.
+                        partially_tokenized_prompt.push_back(token);
+                    } else {
+                        break;
+                    }
+                }
+
+                if (prompt_index > max_prompt_match_in_chars) { // the tokenization based on this slot matches more characters than the previous best match
+                    max_prompt_match_in_chars = prompt_index;
+                    cache_based_tokenization = partially_tokenized_prompt;
+                }
+            }
+        }
+
+        if (max_prompt_match_in_chars > 0) { // if some of the prompt was tokenized based on the slot caches
+            std::string remaining_string = prompt_string.substr(max_prompt_match_in_chars);
+            std::vector<llama_token> remaining_prompt_tokens = common_tokenize(ctx_server.vocab, remaining_string, true, true); // tokenize the rest of the prompt normally
+
+            SRV_DBG("The slot cache based tokenization has produced %zu tokens and the regular tokenization an additional %zu tokens for a total of %zu.\n",
+                cache_based_tokenization.size(), remaining_prompt_tokens.size(), cache_based_tokenization.size() + remaining_prompt_tokens.size());
+
+            // concatenate the additional tokens to the cached tokens, but skip the additional BOS, as we don't need one in the middle of the tokens. This might cause issues.
+            if (remaining_prompt_tokens.size() > 1) {
+                cache_based_tokenization.insert(cache_based_tokenization.end(), remaining_prompt_tokens.begin() + 1, remaining_prompt_tokens.end());
+            }
+
+            tokenized_prompts.push_back(cache_based_tokenization);
+        } else {
+            SRV_DBG("Partial tokenization of the %zu character long prompt based on slot caches was not possible.\n", prompt_string.size());
+        }
+    }
+
+    if (tokenized_prompts.empty()) { // if the slot token cache based tokenization was not possible, tokenize the prompt normally
+        tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+    } // end of new tokenization code based on caches
+
     tasks.reserve(tokenized_prompts.size());
     for (size_t i = 0; i < tokenized_prompts.size(); i++) {
         server_task task = server_task(type);
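
For experimenting with the matching logic outside the server, here is a minimal standalone sketch of the same greedy prefix-matching idea. It is only an illustration under simplifying assumptions: the integer token type, the toy vocabulary, and the detokenize / match_cached_prefix names are stand-ins invented for the example rather than the real llama_token / common_token_to_piece machinery, and the leading-special-token exception mirrors the BOS handling in the diff above.

// Sketch of greedy prefix matching of a prompt string against already cached tokens.
// Assumptions: token is a plain int and detokenize is a caller-supplied callback;
// nothing here touches llama.cpp itself.
#include <cstdio>
#include <functional>
#include <string>
#include <utility>
#include <vector>

using token = int;

// Walk the cached tokens and keep those whose detokenized pieces form a prefix of
// the new prompt; return the kept tokens and the number of prompt characters covered.
static std::pair<std::vector<token>, size_t> match_cached_prefix(
        const std::string & prompt,
        const std::vector<token> & cache,
        const std::function<std::string(token)> & detokenize) {
    std::vector<token> matched;
    size_t pos = 0;
    for (size_t i = 0; i < cache.size() && pos < prompt.size(); ++i) {
        const std::string piece = detokenize(cache[i]);
        if (prompt.compare(pos, piece.size(), piece) == 0) {
            pos += piece.size();
            matched.push_back(cache[i]);
        } else if (i == 0) {
            // tolerate a leading special token (e.g. BOS) whose piece never appears in the prompt text
            matched.push_back(cache[i]);
        } else {
            break;
        }
    }
    return {matched, pos};
}

int main() {
    // toy vocabulary: index 0 plays the role of a BOS token that has no textual counterpart in the prompt
    const std::vector<std::string> pieces = {"<s>", "Hello", " world", "!", " again"};
    const auto detok = [&](token t) { return pieces.at((size_t) t); };

    const std::vector<token> cache = {0, 1, 2, 3};          // cached tokens: "<s>" "Hello" " world" "!"
    const std::string prompt = "Hello world! again";        // new request sharing a prefix with the cache

    const auto [matched, covered] = match_cached_prefix(prompt, cache, detok);
    std::printf("reused %zu cached tokens covering %zu of %zu prompt characters\n",
                matched.size(), covered, prompt.size());
    // the uncovered tail prompt.substr(covered) would then be tokenized normally and appended
    return 0;
}

The server change above does the same thing per slot, keeps the slot whose cache covers the most prompt characters, tokenizes only the uncovered tail, and falls back to tokenize_input_prompts when no cache matches at all.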