Commit 3dbdbca

Hybrid tokenization of chat prompts based on the slot caches.
1 parent 36c258e commit 3dbdbca

File tree

1 file changed: +60 −1 lines changed

examples/server/server.cpp

Lines changed: 60 additions & 1 deletion
@@ -4,6 +4,7 @@
 #include "common.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "llama-vocab.h"
 #include "log.h"
 #include "sampling.h"
 #include "speculative.h"
@@ -3841,7 +3842,65 @@ int main(int argc, char ** argv) {
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());

-        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+        std::vector<llama_tokens> tokenized_prompts; // start of new tokenization code based on caches; it may need optimizations and bug fixes
+        if (prompt.is_string()) { // attempt tokenization based on the slot token caches first, only for prompts consisting of a single string
+            llama_tokens cache_based_tokenization;
+            std::string prompt_string = prompt.get<std::string>();
+            size_t max_prompt_match_in_chars = 0;
+
+            SRV_DBG("Attempting slot cache based tokenization of the prompt, total prompt length %zu characters.\n", prompt_string.size());
+            for (size_t slot_index = 0; slot_index < ctx_server.slots.size(); slot_index++) {
+                size_t prompt_index = 0;
+                size_t cache_index = 0;
+                llama_tokens partially_tokenized_prompt;
+                llama_tokens cache_tokens = ctx_server.slots[slot_index].cache_tokens; // accessing the caches like this might be unsafe
+
+                if (cache_tokens.size() > 0) {
+                    SRV_DBG("Slot %zu has %zu cached tokens, attempting prompt tokenization based on them.\n", slot_index, cache_tokens.size());
+                    for (cache_index = 0; cache_index < cache_tokens.size() && prompt_index < prompt_string.size(); cache_index++) {
+                        llama_token token = cache_tokens[cache_index];
+                        const std::string token_string = common_token_to_piece(ctx_server.vocab, token, true);
+                        size_t token_size = token_string.size();
+
+                        if (prompt_index + token_size <= prompt_string.size() && prompt_string.compare(prompt_index, token_size, token_string) == 0) {
+                            prompt_index += token_size;
+                            partially_tokenized_prompt.push_back(token);
+                        } else if (cache_index == 0) { // the first cached token does not have to appear in the prompt, as it might be a BOS token, so add it unconditionally; this might cause issues
+                            partially_tokenized_prompt.push_back(token);
+                        } else {
+                            break;
+                        }
+                    }
+
+                    if (prompt_index > max_prompt_match_in_chars) { // the tokenization based on this slot matches more characters than the previous best match
+                        max_prompt_match_in_chars = prompt_index;
+                        cache_based_tokenization = partially_tokenized_prompt;
+                    }
+                }
+            }
+
+            if (max_prompt_match_in_chars > 0) { // some of the prompt was tokenized based on the slot caches
+                std::string remaining_string = prompt_string.substr(max_prompt_match_in_chars);
+                std::vector<llama_token> remaining_prompt_tokens = common_tokenize(ctx_server.vocab, remaining_string, true, true); // tokenize the rest of the prompt normally
+
+                SRV_DBG("The slot cache based tokenization has produced %zu tokens and the regular tokenization an additional %zu tokens for a total of %zu.\n",
+                        cache_based_tokenization.size(), remaining_prompt_tokens.size(), cache_based_tokenization.size() + remaining_prompt_tokens.size());
+
+                // concatenate the additional tokens to the cached tokens, but skip the additional BOS, as we don't need one in the middle of the tokens; this might cause issues
+                if (remaining_prompt_tokens.size() > 1) {
+                    cache_based_tokenization.insert(cache_based_tokenization.end(), remaining_prompt_tokens.begin() + 1, remaining_prompt_tokens.end());
+                }
+
+                tokenized_prompts.push_back(cache_based_tokenization);
+            } else {
+                SRV_DBG("Partial tokenization of the %zu character long prompt based on slot caches was not possible.\n", prompt_string.size());
+            }
+        }
+
+        if (tokenized_prompts.empty()) { // if the slot token cache based tokenization was not possible, tokenize the prompt normally
+            tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+        } // end of new tokenization code based on caches
+
         tasks.reserve(tokenized_prompts.size());
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
             server_task task = server_task(type);
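
The idea behind the change, restated outside of server.cpp: walk a slot's cached tokens and keep reusing them as long as their text pieces match the next characters of the incoming prompt, then run the regular tokenizer only on the uncovered suffix. The sketch below is a minimal, self-contained illustration of that prefix-matching pass; toy_piece and toy_tokenize are hypothetical stand-ins for common_token_to_piece and common_tokenize, and the fixed piece table exists only for this example.

// Minimal sketch of the hybrid tokenization idea, assuming toy stand-ins:
// toy_piece() plays the role of common_token_to_piece() and toy_tokenize()
// the role of common_tokenize(); neither is part of the llama.cpp API.
#include <cstdio>
#include <string>
#include <vector>

using token = int;

// stand-in detokenizer: maps a token id to its text piece
static std::string toy_piece(token t) {
    static const std::vector<std::string> pieces = {
        "<BOS>", "Hello", ",", " world", "!"
    };
    return pieces.at(static_cast<size_t>(t));
}

// stand-in regular tokenizer: emits one fake id per 4-character chunk of the text
static std::vector<token> toy_tokenize(const std::string & text) {
    std::vector<token> out;
    for (size_t i = 0; i < text.size(); i += 4) {
        out.push_back(1000 + static_cast<token>(i)); // fake ids for the suffix
    }
    return out;
}

int main() {
    // tokens a slot might still hold from a previous request (token 0 is a BOS-like token with no text in the prompt)
    const std::vector<token> cache_tokens = {0, 1, 2, 3, 4};
    const std::string prompt = "Hello, world! How are you?";

    std::vector<token> result;
    size_t matched_chars = 0;

    // walk the cache as long as each token's piece matches the next part of the prompt
    for (size_t i = 0; i < cache_tokens.size() && matched_chars < prompt.size(); ++i) {
        const std::string piece = toy_piece(cache_tokens[i]);
        if (prompt.compare(matched_chars, piece.size(), piece) == 0) {
            matched_chars += piece.size();
            result.push_back(cache_tokens[i]);
        } else if (i == 0) {
            result.push_back(cache_tokens[i]); // allow a leading BOS-like token that has no text in the prompt
        } else {
            break;
        }
    }

    // tokenize only the part of the prompt the cache did not cover
    const std::vector<token> suffix = toy_tokenize(prompt.substr(matched_chars));
    result.insert(result.end(), suffix.begin(), suffix.end());

    std::printf("reused %zu cached tokens covering %zu of %zu characters, %zu suffix tokens\n",
                result.size() - suffix.size(), matched_chars, prompt.size(), suffix.size());
    return 0;
}

In the commit itself the same loop additionally keeps the best match across all slots and drops the duplicate BOS token that the suffix tokenization produces.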
