@@ -20,47 +20,49 @@
 
 namespace dlib
 {
+    constexpr size_t BPE_TOKENIZER_MAX_TOKEN_LENGTH = 8;
+    constexpr int BPE_TOKENIZER_BASE_VOCAB_SIZE = 256;
 
     class bpe_tokenizer
     {
    public:
-        bpe_tokenizer() : vocab_size(BASE_VOCAB_SIZE)
+        bpe_tokenizer() : vocab_size(BPE_TOKENIZER_BASE_VOCAB_SIZE)
         {
             // Initialize the base vocabulary with single bytes
-            for (int i = 0; i < BASE_VOCAB_SIZE; ++i)
+            for (int i = 0; i < BPE_TOKENIZER_BASE_VOCAB_SIZE; ++i)
                 vocab[i] = std::vector<uint8_t>{ static_cast<uint8_t>(i) };
 
             // Initialize special tokens with sequential IDs
             special_tokens =
             {
-                {"<text>", BASE_VOCAB_SIZE},
-                {"</text>", BASE_VOCAB_SIZE + 1},
-                {"<url>", BASE_VOCAB_SIZE + 2},
-                {"</url>", BASE_VOCAB_SIZE + 3},
-                {"<image>", BASE_VOCAB_SIZE + 4},
-                {"</image>", BASE_VOCAB_SIZE + 5},
-                {"<video>", BASE_VOCAB_SIZE + 6},
-                {"</video>", BASE_VOCAB_SIZE + 7},
-                {"<audio>", BASE_VOCAB_SIZE + 8},
-                {"</audio>", BASE_VOCAB_SIZE + 9},
-                {"<file>", BASE_VOCAB_SIZE + 10},
-                {"</file>", BASE_VOCAB_SIZE + 11},
-                {"<code>", BASE_VOCAB_SIZE + 12},
-                {"</code>", BASE_VOCAB_SIZE + 13},
-                {"<summary>", BASE_VOCAB_SIZE + 14},
-                {"</summary>", BASE_VOCAB_SIZE + 15},
-                {"<think>", BASE_VOCAB_SIZE + 16},
-                {"</think>", BASE_VOCAB_SIZE + 17},
-                {"<start>", BASE_VOCAB_SIZE + 18},
-                {"<end>", BASE_VOCAB_SIZE + 19},
-                {"<user>", BASE_VOCAB_SIZE + 20},
-                {"<bot>", BASE_VOCAB_SIZE + 21},
-                {"<system>", BASE_VOCAB_SIZE + 22},
-                {"<question>", BASE_VOCAB_SIZE + 23},
-                {"<answer>", BASE_VOCAB_SIZE + 24},
-                {"<search>", BASE_VOCAB_SIZE + 25},
-                {"<unk>", BASE_VOCAB_SIZE + 26},
-                {"<pad>", BASE_VOCAB_SIZE + 27}
+                {"<text>", BPE_TOKENIZER_BASE_VOCAB_SIZE},
+                {"</text>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 1},
+                {"<url>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 2},
+                {"</url>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 3},
+                {"<image>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 4},
+                {"</image>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 5},
+                {"<video>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 6},
+                {"</video>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 7},
+                {"<audio>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 8},
+                {"</audio>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 9},
+                {"<file>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 10},
+                {"</file>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 11},
+                {"<code>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 12},
+                {"</code>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 13},
+                {"<summary>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 14},
+                {"</summary>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 15},
+                {"<think>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 16},
+                {"</think>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 17},
+                {"<start>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 18},
+                {"<end>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 19},
+                {"<user>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 20},
+                {"<bot>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 21},
+                {"<system>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 22},
+                {"<question>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 23},
+                {"<answer>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 24},
+                {"<search>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 25},
+                {"<unk>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 26},
+                {"<pad>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 27}
             };
 
             // Initialize the vector of special token IDs
@@ -71,9 +73,9 @@ namespace dlib
         // Train the tokenizer on the given text
         void train(const std::string& text, int vocab_size, bool verbose = false)
         {
-            DLIB_CASSERT(vocab_size >= BASE_VOCAB_SIZE);
+            DLIB_CASSERT(vocab_size >= BPE_TOKENIZER_BASE_VOCAB_SIZE);
             this->vocab_size = vocab_size;
-            int num_merges = vocab_size - BASE_VOCAB_SIZE;
+            int num_merges = vocab_size - BPE_TOKENIZER_BASE_VOCAB_SIZE;
 
             // Convert text to byte IDs
             std::vector<int> ids;
@@ -84,25 +86,25 @@ namespace dlib
                 auto stats = get_stats(ids);
                 if (stats.empty()) break;
 
-                // Find the most frequent pair that does not exceed MAX_TOKEN_LENGTH
+                // Find the most frequent pair that does not exceed BPE_TOKENIZER_MAX_TOKEN_LENGTH
                 auto pair = get_most_frequent_pair(stats);
 
-                // Check if the resulting token would exceed MAX_TOKEN_LENGTH
+                // Check if the resulting token would exceed BPE_TOKENIZER_MAX_TOKEN_LENGTH
                 size_t new_token_length = vocab[pair.first].size() + vocab[pair.second].size();
-                if (new_token_length > MAX_TOKEN_LENGTH) {
+                if (new_token_length > BPE_TOKENIZER_MAX_TOKEN_LENGTH) {
                     if (verbose)
                     {
                         std::cout << "\r"
                             << std::setw(100) << std::flush
                             << "\rskipping merge " << std::to_string(i + 1) << "/" << std::to_string(num_merges) << ": ("
                             << std::to_string(pair.first) << "," << std::to_string(pair.second) << ") -> new token length "
-                            << std::to_string(new_token_length) << " exceeds limit of " << std::to_string(MAX_TOKEN_LENGTH)
+                            << std::to_string(new_token_length) << " exceeds limit of " << std::to_string(BPE_TOKENIZER_MAX_TOKEN_LENGTH)
                             << std::flush;
                     }
                     continue; // Skip this merge
                 }
 
-                int idx = (BASE_VOCAB_SIZE + (int)special_tokens.size()) + i;
+                int idx = (BPE_TOKENIZER_BASE_VOCAB_SIZE + (int)special_tokens.size()) + i;
                 ids = merge(ids, pair, idx);
                 merges[pair] = idx;
                 vocab[idx].insert(vocab[idx].end(), vocab[pair.first].begin(), vocab[pair.first].end());
@@ -287,9 +289,6 @@ namespace dlib
         std::map<int, std::vector<uint8_t>> vocab;
         int vocab_size;
 
-        static const size_t MAX_TOKEN_LENGTH = 8;
-        static const int BASE_VOCAB_SIZE = 256;
-
         // Get frequency statistics of adjacent token pairs
         struct pair_hash {
             template <class T1, class T2>
@@ -344,10 +343,10 @@ namespace dlib
 
             // Check if the new token formed by merging the pair would exceed the maximum allowed length
             size_t new_token_length = vocab.at(pair.first).size() + vocab.at(pair.second).size();
-            if (new_token_length > MAX_TOKEN_LENGTH) continue; // Skip this pair if it exceeds the maximum token length
+            if (new_token_length > BPE_TOKENIZER_MAX_TOKEN_LENGTH) continue; // Skip this pair if it exceeds the maximum token length
 
             // Calculate the score for this pair (frequency * length_penalty)
-            double score = (size_t)count * (new_token_length > (MAX_TOKEN_LENGTH / 2) ? 1.75 : 1.0);
+            double score = (size_t)count * (new_token_length > (BPE_TOKENIZER_MAX_TOKEN_LENGTH / 2) ? 1.75 : 1.0);
 
             // Update the best pair if the current pair has a higher score
             if (score > max_score)
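
For context, a minimal sketch of how the renamed namespace-level constants are exercised through the train() method shown in this diff. The include path, sample corpus, and target vocabulary size below are assumptions for illustration only and are not part of the change.

// Minimal usage sketch (assumed include path; corpus and sizes are placeholders).
#include <dlib/tokenizer/bpe_tokenizer.h>
#include <iostream>
#include <string>

int main()
{
    dlib::bpe_tokenizer tok;

    // The target vocabulary must at least cover the 256 single-byte base tokens.
    const int target_vocab = dlib::BPE_TOKENIZER_BASE_VOCAB_SIZE + 512;

    const std::string corpus = "the quick brown fox jumps over the lazy dog";
    tok.train(corpus, target_vocab, /*verbose=*/true);

    // Merged tokens are capped at BPE_TOKENIZER_MAX_TOKEN_LENGTH bytes during training.
    std::cout << "max merged token length: " << dlib::BPE_TOKENIZER_MAX_TOKEN_LENGTH << "\n";
    return 0;
}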