Skip to content

Commit 0306deb

Browse files
committed
run : avoid double tokenization by adopting common_tokenize heuristic
1 parent 5d5c066 commit 0306deb

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

tools/run/run.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -939,17 +939,30 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
939939
// Function to tokenize the prompt
940940
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
941941
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
942-
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == 0;
943-
944-
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
945-
prompt_tokens.resize(n_prompt_tokens);
946-
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
947-
true) < 0) {
948-
printe("failed to tokenize the prompt\n");
942+
const bool add_special = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == 0;
943+
int n_tokens = prompt.size() + 2 * add_special;
944+
prompt_tokens.resize(n_tokens);
945+
n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
946+
prompt_tokens.data(), prompt_tokens.size(),
947+
add_special, /*parse_special =*/true);
948+
if (n_tokens == std::numeric_limits<int32_t>::min()) {
949+
printe("tokenization failed: input too large\n");
949950
return -1;
950951
}
951-
952-
return n_prompt_tokens;
952+
if (n_tokens < 0) {
953+
prompt_tokens.resize(-n_tokens);
954+
int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
955+
prompt_tokens.data(), prompt_tokens.size(),
956+
add_special, /*parse_special =*/true);
957+
if (check != -n_tokens) {
958+
printe("failed to tokenize the prompt (size mismatch)\n");
959+
return -1;
960+
}
961+
n_tokens = check;
962+
} else {
963+
prompt_tokens.resize(n_tokens);
964+
}
965+
return n_tokens;
953966
}
954967

955968
// Check if we have enough space in the context to evaluate this batch

0 commit comments

Comments
 (0)