6 changes: 1 addition & 5 deletions src/llama-vocab.cpp
@@ -1570,11 +1570,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
 }
 
 bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && (
-        token == llama_token_eos_impl(vocab) ||
-        token == llama_token_eot_impl(vocab) ||
-        token == llama_token_eom_impl(vocab)
-    );
+    return token != -1 && vocab.special_eog_ids.count(token) > 0;
 }
 
 bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
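Why this matters: the old implementation compared the token against at most three ids (EOS, EOT, EOM), so a vocab containing additional end-of-generation tokens (the motivating case in issue #9606) could never have all of them recognized. The new single set lookup covers every token collected into `special_eog_ids` at load time. A minimal, self-contained sketch of the new check against a mocked-up vocab (the struct and token ids are invented for illustration; they are not llama.cpp types):

#include <cstdint>
#include <cstdio>
#include <set>

using llama_token = int32_t;

// stand-in for the relevant part of llama_vocab (illustration only)
struct mock_vocab {
    std::set<llama_token> special_eog_ids;
};

// mirrors the new llama_token_is_eog_impl: one set lookup instead of
// comparing against the three dedicated EOS/EOT/EOM ids
static bool is_eog(const mock_vocab & vocab, llama_token token) {
    return token != -1 && vocab.special_eog_ids.count(token) > 0;
}

int main() {
    mock_vocab vocab;
    // invented ids: EOS plus two extra end-of-generation tokens that the
    // old three-way comparison could not all cover at once
    vocab.special_eog_ids = { 2, 107, 32010 };

    std::printf("%d %d %d\n", is_eog(vocab, 2), is_eog(vocab, 107), is_eog(vocab, -1)); // 1 1 0
    return 0;
}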
14 changes: 9 additions & 5 deletions src/llama-vocab.h
@@ -6,6 +6,7 @@
 #include <vector>
 #include <unordered_map>
 #include <map>
+#include <set>
 
 struct llama_vocab {
     using id = llama_token;
@@ -49,12 +50,15 @@ struct llama_vocab {
     id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
     id special_eom_id = -1;
 
+    // set of all tokens that cause "end of generation"
+    std::set<id> special_eog_ids;
+
     // tokenizer flags
-    bool tokenizer_add_space_prefix = false;
-    bool tokenizer_add_bos          = false;
-    bool tokenizer_add_eos          = false;
-    bool tokenizer_ignore_merges    = false;
-    bool tokenizer_clean_spaces     = false;  // clean_up_tokenization_spaces
+    bool tokenizer_add_space_prefix           = false;
+    bool tokenizer_add_bos                    = false;
+    bool tokenizer_add_eos                    = false;
+    bool tokenizer_ignore_merges              = false;
+    bool tokenizer_clean_spaces               = false;  // clean_up_tokenization_spaces
     bool tokenizer_remove_extra_whitespaces   = false;
     bool tokenizer_escape_whitespaces         = true;
     bool tokenizer_treat_whitespace_as_suffix = false;
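A note on the data-structure choice: `std::set` (rather than `std::unordered_set`) iterates its keys in sorted order, so any walk over `special_eog_ids`, such as the new `EOG token = ...` log loop in `llm_load_print_meta` below, is deterministic across runs. The set is tiny and typically consulted once per sampled token, so tree-lookup cost is irrelevant here. A quick self-contained sketch (token ids invented):

#include <cstdint>
#include <cstdio>
#include <set>

using id = int32_t;

int main() {
    std::set<id> special_eog_ids;
    // insertion order does not matter: iteration is always ascending by id
    special_eog_ids.insert(128009);
    special_eog_ids.insert(2);
    special_eog_ids.insert(107);
    for (id t : special_eog_ids) {
        std::printf("EOG token = %d\n", t); // prints 2, then 107, then 128009
    }
    return 0;
}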
59 changes: 51 additions & 8 deletions src/llama.cpp
@@ -6509,21 +6509,21 @@ static void llm_load_vocab(
         // for now, we apply this workaround to find the EOT token based on its text
         if (vocab.special_eot_id == -1) {
             for (const auto & t : vocab.token_to_id) {
-                if (
+                if (false
                         // TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
                         //       need to fix convert script
                         //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
-                        (t.first == "<|eot_id|>" ||
-                         t.first == "<|im_end|>" ||
-                         t.first == "<|end|>" ||
-                         t.first == "<end_of_turn>" ||
-                         t.first == "<|endoftext|>"
-                        )
+                        || t.first == "<|eot_id|>"
+                        || t.first == "<|im_end|>"
+                        || t.first == "<|end|>"
+                        || t.first == "<end_of_turn>"
+                        || t.first == "<|endoftext|>"
+                        || t.first == "<EOT>"
                    ) {
                     vocab.special_eot_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                                __func__, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                     break;
@@ -6546,6 +6546,44 @@ static void llm_load_vocab(
                 }
             }
         }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        vocab.special_eog_ids.clear();
+        for (const auto & t : vocab.token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == "<end_of_turn>"
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == "<EOT>"
+               ) {
+                vocab.special_eog_ids.insert(t.second);
+                if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.first.c_str());
+                    vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            }
+        }
+
+        if (vocab.special_eos_id != -1 && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eot_id != -1 && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eom_id != -1 && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
     }
 
     // build special tokens cache
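The three trailing `if` blocks above apply one guard per dedicated special id: whatever the text scan found, EOS/EOT/EOM must still end up in the set, and a miss is worth a warning because it points at an odd tokenizer config. The same pattern, factored into a helper, as a runnable sketch (the helper and all ids are illustrative, not part of the PR):

#include <cstdint>
#include <cstdio>
#include <set>

using llama_token = int32_t;

// illustrative helper: force-include a dedicated special id in the EOG set,
// warning when the text-based scan had missed it
static void ensure_eog(std::set<llama_token> & eog_ids, llama_token id, const char * name) {
    if (id != -1 && eog_ids.count(id) == 0) {
        eog_ids.insert(id);
        std::printf("warn: %s is not in special_eog_ids - the tokenizer config may be incorrect\n", name);
    }
}

int main() {
    std::set<llama_token> special_eog_ids = { 128009 }; // pretend the text scan found only one token
    ensure_eog(special_eog_ids, 2,  "special_eos_id");  // missed by the scan: inserted + warning
    ensure_eog(special_eog_ids, -1, "special_eom_id");  // id not set for this model: no-op
    std::printf("eog set size = %zu\n", special_eog_ids.size()); // 2
    return 0;
}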
@@ -6749,6 +6787,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
     if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }
+    if (vocab.special_eom_id != -1) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, vocab.special_eom_id, vocab.id_to_token[vocab.special_eom_id].text.c_str() ); }
+
+    for (const auto & id : vocab.special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
+    }
 
     LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
 
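From the API consumer's side nothing needs to change: generation loops that already stop on `llama_token_is_eog()` now transparently honor every token in `special_eog_ids`. A sketch of such a loop (it uses the real `llama.h` entry points `llama_get_model` and `llama_token_is_eog`; `sample_next_token` is a hypothetical helper standing in for the usual sampling code, and model/context setup is omitted):

#include "llama.h"

// hypothetical helper standing in for the logits -> token sampling step
llama_token sample_next_token(llama_context * ctx);

void generate(llama_context * ctx, int n_max) {
    const llama_model * model = llama_get_model(ctx);
    for (int i = 0; i < n_max; ++i) {
        const llama_token tok = sample_next_token(ctx);
        // previously equivalent to tok == EOS/EOT/EOM; now also true for any
        // other token collected into vocab.special_eog_ids at load time
        if (llama_token_is_eog(model, tok)) {
            break;
        }
        // ... decode tok, accumulate output ...
    }
}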