@@ -2162,7 +2162,7 @@ struct llama_vocab {
21622162 std::unordered_map<token, id> token_to_id;
21632163 std::vector<token_data> id_to_token;
21642164
2165- std::unordered_map<token, id> special_tokens_cache;
2165+ std::vector< id> special_tokens_cache;
21662166
21672167 std::map<std::pair<std::string, std::string>, int> bpe_ranks;
21682168
@@ -4831,97 +4831,19 @@ static void llm_load_vocab(
48314831
48324832 // build special tokens cache
48334833 {
4834- // TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type,
4835- // and will always be correctly labeled in 'added_tokens.json' etc.
4836- // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed
4837- // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer
4838- // are special tokens.
4839- // From testing, this appears to correlate 1:1 with special tokens.
4840- //
4841-
4842- // Counting special tokens and verifying in only one direction
4843- // is sufficient to detect difference in those two sets.
4844- //
4845- uint32_t special_tokens_count_by_type = 0;
4846- uint32_t special_tokens_count_from_verification = 0;
4847-
4848- bool special_tokens_definition_mismatch = false;
4849-
4850- for (const auto & t : vocab.token_to_id) {
4851- const auto & token = t.first;
4852- const auto & id = t.second;
4853-
4854- // Count all non-normal tokens in the vocab while iterating
4834+ for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
48554835 if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
4856- special_tokens_count_by_type++ ;
4836+ vocab.special_tokens_cache.push_back(id) ;
48574837 }
4838+ }
48584839
4859- // Skip single character tokens
4860- if (token.length() > 1) {
4861- bool is_tokenizable = false;
4862-
4863- // Split token string representation in two, in all possible ways
4864- // and check if both halves can be matched to a valid token
4865- for (unsigned i = 1; i < token.length();) {
4866- const auto left = token.substr(0, i);
4867- const auto right = token.substr(i);
4868-
4869- // check if we didnt partition in the middle of a utf sequence
4870- auto utf = utf8_len(left.at(left.length() - 1));
4871-
4872- if (utf == 1) {
4873- if (vocab.token_to_id.find(left) != vocab.token_to_id.end() &&
4874- vocab.token_to_id.find(right) != vocab.token_to_id.end() ) {
4875- is_tokenizable = true;
4876- break;
4877- }
4878- i++;
4879- } else {
4880- // skip over the rest of multibyte utf sequence
4881- i += utf - 1;
4882- }
4883- }
4884-
4885- if (!is_tokenizable) {
4886- // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1
4887- // it's faster to re-filter them here, since there are way less candidates now
4888-
4889- // Calculate a total "utf" length of a token string representation
4890- size_t utf8_str_len = 0;
4891- for (unsigned i = 0; i < token.length();) {
4892- utf8_str_len++;
4893- i += utf8_len(token.at(i));
4894- }
4895-
4896- // And skip the ones which are one character
4897- if (utf8_str_len > 1) {
4898- // At this point what we have left are special tokens only
4899- vocab.special_tokens_cache[token] = id;
4900-
4901- // Count manually found special tokens
4902- special_tokens_count_from_verification++;
4903-
4904- // If this manually found special token is not marked as such, flag a mismatch
4905- if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
4906- special_tokens_definition_mismatch = true;
4907- }
4908- }
4909- }
4840+ std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
4841+ [&] (const llama_vocab::id a, const llama_vocab::id b) {
4842+ return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
49104843 }
4911- }
4844+ );
49124845
4913- if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
4914- LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
4915- __func__,
4916- special_tokens_count_from_verification, vocab.id_to_token.size(),
4917- special_tokens_count_by_type, vocab.id_to_token.size()
4918- );
4919- } else {
4920- LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n",
4921- __func__,
4922- special_tokens_count_from_verification, vocab.id_to_token.size()
4923- );
4924- }
4846+ LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
49254847 }
49264848}
49274849
@@ -13146,7 +13068,7 @@ struct llm_tokenizer_wpm {
1314613068 llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
1314713069
1314813070 void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
13149- auto * token_map = & vocab.token_to_id;
13071+ const auto & token_map = vocab.token_to_id;
1315013072
1315113073 // normalize and split by whitespace
1315213074 std::vector<std::string> words = preprocess(text);
@@ -13161,108 +13083,89 @@ struct llm_tokenizer_wpm {
1316113083 }
1316213084
1316313085 // prepend phantom space
13164- std::string word1 = "\xe2\x96\x81" + word;
13165- int n = word1.size();
13086+ const std::string word1 = "\xe2\x96\x81" + word;
13087+ const int n = word1.size();
1316613088
13167- // we're at the start of a new word
13168- int i = 0;
13169- bool match_any = false;
13089+ const size_t current_tokens = output.size();
1317013090
13091+ // we're at the start of a new word
1317113092 // move through character position in word
13172- while (i < n) {
13093+ for (int i = 0; i < n; ++i ) {
1317313094 // loop through possible match length
1317413095 bool match = false;
1317513096 for (int j = n; j > i; j--) {
13176- auto it = token_map-> find(word1.substr(i, j - i));
13177- if (it != token_map-> end()) {
13097+ auto it = token_map. find(word1.substr(i, j - i));
13098+ if (it != token_map. end()) {
1317813099 output.push_back(it->second);
1317913100 match = true;
13180- match_any = true;
13181- i = j;
13101+ i = j - 1;
1318213102 break;
1318313103 }
1318413104 }
1318513105
13186- // must be an unknown character
13187- if (!match) {
13188- i++;
13106+ if (!match) { // discard all
13107+ output.resize(current_tokens);
13108+ break; // and discard next tokens
1318913109 }
1319013110 }
1319113111
1319213112 // we didn't find any matches for this word
13193- if (!match_any ) {
13113+ if (current_tokens == output.size() ) {
1319413114 output.push_back(vocab.special_unk_id);
1319513115 }
1319613116 }
1319713117 }
1319813118
1319913119 std::vector<std::string> preprocess(const std::string & text) {
13200- std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
13201-
13202- // strip accents, strip control, uniformize whitespace,
13203- // to lowercase, pad chinese characters, pad punctuation
13204- std::string new_str = "";
13205- for (uint32_t code : cpts_nfd) {
13206- const codepoint_flags flags = unicode_cpt_flags(code);
13207- if (flags.is_accent_mark || flags.is_control) {
13120+ const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
13121+ std::vector<std::string> words(1, "");
13122+
13123+ for (const char32_t cpt : cpts_nfd) {
13124+ const auto flags = unicode_cpt_flags(cpt);
13125+
13126+ if (flags.is_whitespace) {
13127+ if (words.back().size()) { // finish previous word if any
13128+ words.emplace_back();
13129+ }
1320813130 continue;
1320913131 }
13210- code = unicode_tolower(code);
13211- if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ?
13212- code = ' ';
13213- }
13214- std::string s = unicode_cpt_to_utf8(code);
13215- if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
13216- new_str += " ";
13217- new_str += s;
13218- new_str += " ";
13219- } else {
13220- new_str += s;
13132+
13133+ assert (!flags.is_separator);
13134+ if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
13135+ continue;
1322113136 }
13222- }
1322313137
13224- // split by whitespace
13225- uint64_t l = 0;
13226- uint64_t r = 0;
13227- std::vector<std::string> words;
13228- while (r < new_str.size()) {
13229- // if is whitespace
13230- if (isspace(new_str[r], std::locale::classic())) {
13231- if (r > l) words.push_back(new_str.substr(l, (r - l)));
13232- l = r + 1;
13233- r = l;
13138+ const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
13139+ if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
13140+ if (words.back().size()) { // finish previous word if any
13141+ words.emplace_back();
13142+ }
13143+ words.back() = s; // single char word
13144+ words.emplace_back(); // start a new word
1323413145 } else {
13235- r += 1;
13146+ words.back() += s; // append char to word
1323613147 }
1323713148 }
13238- if (r > l) {
13239- words.push_back(new_str.substr(l, (r - l)));
13240- }
13241- return words;
13242- }
1324313149
13244- bool is_ascii_punct(uint32_t code) {
13245- if (code > 0xFF) {
13246- return false;
13150+ if (!words.back().size()) {
13151+ words.pop_back();
1324713152 }
13248- auto c = char(static_cast<unsigned char>(code));
13249- return ispunct(c, std::locale::classic()) ;
13153+
13154+ return words ;
1325013155 }
1325113156
13252- bool is_chinese_char(uint32_t cpt) {
13253- if ((cpt >= 0x4E00 && cpt <= 0x9FFF) ||
13254- (cpt >= 0x3400 && cpt <= 0x4DBF) ||
13157+ static bool is_chinese_char(uint32_t cpt) {
13158+ return
13159+ (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
13160+ (cpt >= 0x03400 && cpt <= 0x04DBF) ||
1325513161 (cpt >= 0x20000 && cpt <= 0x2A6DF) ||
1325613162 (cpt >= 0x2A700 && cpt <= 0x2B73F) ||
1325713163 (cpt >= 0x2B740 && cpt <= 0x2B81F) ||
1325813164 (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
13259- (cpt >= 0xF900 && cpt <= 0xFAFF) ||
13260- (cpt >= 0x2F800 && cpt <= 0x2FA1F) ||
13261- (cpt >= 0x3000 && cpt <= 0x303F) ||
13262- (cpt >= 0xFF00 && cpt <= 0xFFEF)) {
13263- return true; // NOLINT
13264- }
13265- return false;
13165+ (cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
13166+ (cpt >= 0x2F800 && cpt <= 0x2FA1F);
13167+ //(cpt >= 0x3000 && cpt <= 0x303F) ||
13168+ //(cpt >= 0xFF00 && cpt <= 0xFFEF);
1326613169 }
1326713170
1326813171 const llama_vocab & vocab;
@@ -13306,9 +13209,8 @@ struct fragment_buffer_variant {
1330613209
1330713210static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
1330813211 // for each special token
13309- for (const auto & st: vocab.special_tokens_cache) {
13310- const auto & special_token = st.first;
13311- const auto & special_id = st.second;
13212+ for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
13213+ const auto & special_token = vocab.id_to_token[special_id].text;
1331213214
1331313215 // for each text fragment
1331413216 std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
@@ -13317,7 +13219,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
1331713219
1331813220 // if a fragment is text ( not yet processed )
1331913221 if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
13320- auto * raw_text = &( fragment.raw_text) ;
13222+ auto & raw_text = fragment.raw_text;
1332113223
1332213224 auto raw_text_base_offset = fragment.offset;
1332313225 auto raw_text_base_length = fragment.length;
@@ -13327,7 +13229,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
1332713229 // find the first occurrence of a given special token in this fragment
1332813230 // passing offset argument only limit the "search area" but match coordinates
1332913231 // are still relative to the source full raw_text
13330- auto match = raw_text-> find(special_token, raw_text_base_offset);
13232+ auto match = raw_text. find(special_token, raw_text_base_offset);
1333113233
1333213234 // no occurrences found, stop processing this fragment for a given special token
1333313235 if (match == std::string::npos) break;
@@ -13346,7 +13248,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
1334613248 // left
1334713249 const int64_t left_reminder_offset = raw_text_base_offset + 0;
1334813250 const int64_t left_reminder_length = match - raw_text_base_offset;
13349- buffer.emplace_after(it, (* raw_text) , left_reminder_offset, left_reminder_length);
13251+ buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
1335013252
1335113253#ifdef PRETOKENIZERDEBUG
1335213254 LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
@@ -13362,7 +13264,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
1336213264 if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
1336313265 const int64_t right_reminder_offset = match + special_token.length();
1336413266 const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
13365- buffer.emplace_after(it, (* raw_text) , right_reminder_offset, right_reminder_length);
13267+ buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
1336613268
1336713269#ifdef PRETOKENIZERDEBUG
1336813270 LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
0 commit comments