examples/server/server.cpp (2 changes: 1 addition & 1 deletion)

@@ -2572,7 +2572,7 @@ struct server_context {
     GGML_ASSERT(slot.ga_n == 1);
 
     // reuse any previously computed tokens that are common with the new prompt
-    slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+    slot.n_past = common_part(ctx, model, slot.cache_tokens, slot.prompt);
 
     // push the prompt into the sampling context (do not apply grammar)
     for (int i = 0; i < slot.n_past; ++i) {
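
Editor's note: for context, a hedged sketch of what the new call site computes, assuming (as the diff suggests) that slot.cache_tokens holds the tokens of the previously evaluated prompt and slot.prompt holds the raw text of the new request:

    // n_past counts the leading cached tokens whose detokenized pieces form a
    // prefix of the new prompt text; their KV-cache entries are reused and
    // evaluation resumes at position n_past instead of re-processing them.
    slot.n_past = common_part(ctx, model, slot.cache_tokens, slot.prompt);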
examples/server/utils.hpp (25 changes: 25 additions & 0 deletions)

@@ -314,6 +314,31 @@ static size_t common_part(const std::string & a, const std::string & b) {
     return i;
 }
 
+static size_t common_part(const llama_context * ctx, const llama_model * model, const std::vector<llama_token> & a, const std::string & b) {
+    size_t pos       = 0;
+    size_t token_idx = 0;
+
+    for (const auto & token : a) {
+        std::string piece = llama_token_to_piece(ctx, token);
+
+        // the piece matches the prompt text at the current position - consume it
+        if (pos + piece.size() <= b.size() && b.compare(pos, piece.size(), piece) == 0) {
+            pos += piece.size();
+            token_idx++;
+            continue;
+        }
+
+        // handle the auto-inserted BOS token, whose piece does not appear in the prompt text
+        if (token_idx == 0 && token == llama_token_bos(model)) {
+            token_idx++;
+            continue;
+        }
+
+        // first mismatch - everything up to here is the common part
+        return token_idx;
+    }
+
+    return token_idx;
+}
+
 static bool ends_with(const std::string & str, const std::string & suffix) {
     return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }
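
Editor's note: for intuition, a self-contained sketch of the same matching idea with llama_token_to_piece and llama_token_bos replaced by stub stand-ins; all names below are illustrative, not llama.cpp API, and the piece table is invented for the example:

    #include <cstddef>
    #include <string>
    #include <vector>

    // Illustrative stand-ins for the model's token -> piece mapping (not llama.cpp API).
    static const std::vector<std::string> k_pieces = { "<s>", "Hello", ",", " world" };
    static const int k_bos_id = 0; // "<s>" plays the role of the auto-inserted BOS token

    // Mirrors the loop of the new common_part overload: count leading tokens whose
    // pieces prefix-match the prompt string, letting a leading BOS match "for free".
    static size_t common_part_sketch(const std::vector<int> & tokens, const std::string & prompt) {
        size_t pos = 0;
        size_t token_idx = 0;
        for (int token : tokens) {
            const std::string & piece = k_pieces[token];
            if (pos + piece.size() <= prompt.size() && prompt.compare(pos, piece.size(), piece) == 0) {
                pos += piece.size();
                token_idx++;
                continue;
            }
            if (token_idx == 0 && token == k_bos_id) { // BOS text never appears in the prompt
                token_idx++;
                continue;
            }
            return token_idx;
        }
        return token_idx;
    }

    int main() {
        // Cached tokens: [BOS, "Hello", ","]; the new prompt "Hello world" shares
        // BOS (implicitly) and "Hello", so two tokens can be reused.
        const std::vector<int> cached = { k_bos_id, 1, 2 };
        return common_part_sketch(cached, "Hello world") == 2 ? 0 : 1;
    }

In the real overload the pieces come from the model's vocabulary via llama_token_to_piece, so the comparison is against exact detokenization rather than this toy table.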