1+ // Copyright (C) 2025 Davis E. King ([email protected])
2+ // License: Boost Software License See LICENSE.txt for the full license.
13#ifndef DLIB_BPE_TOKENIZER_H
24#define DLIB_BPE_TOKENIZER_H
35
1416
1517#include " ../base64.h"
1618#include " ../serialize.h"
19+ #include " bpe_tokenizer_abstract.h"
1720
1821namespace dlib
1922{
@@ -113,10 +116,9 @@ namespace dlib
113116 << std::to_string (pair.first ) << " ," << std::to_string (pair.second ) << " ) -> " << std::to_string (idx)
114117 << " (" << bytes_to_string (vocab[idx]) << " ) had "
115118 << std::to_string (stats[pair]) << " occurrences"
116- << std::flush ;
119+ << std::endl ;
117120 }
118121 }
119- std::cout << " \n training done\n " ;
120122 }
121123
122124 // Encode the given text into subword tokens
@@ -243,65 +245,25 @@ namespace dlib
243245 // Save the tokenizer model and vocabulary to file
// Writes `tok` to `out`. In this diff view, lines marked `-` are the previous
// on-disk format and lines marked `+` are its replacement.
// Previous format: the "bpe_tokenizer_" tag, the merge count, the merge pair for
// each id in [BASE_VOCAB_SIZE + special_tokens.size(), vocab_size + special_tokens.size()),
// then the vocab count followed by (bytes_to_string(bytes), id) entries.
// New format: the "bpe_tokenizer2_" tag followed by special_tokens,
// special_token_map, merges, vocab and vocab_size, each written whole.
244246 friend void serialize (const bpe_tokenizer& tok, std::ostream& out)
245247 {
246- dlib::serialize (" bpe_tokenizer_" , out);
247-
248- // ---
249- int nb_merges = tok.merges .size ();
250- dlib::serialize (nb_merges, out);
251- for (int idx = (BASE_VOCAB_SIZE + (int )tok.special_tokens .size ());
252- idx < (tok.vocab_size + (int )tok.special_tokens .size ()); ++idx)
253- {
254- for (const auto & merge_pair : tok.merges )
255- {
256- if (merge_pair.second == idx)
257- {
258- dlib::serialize (merge_pair.first .first , out);
259- dlib::serialize (merge_pair.first .second , out);
260- break ;
261- }
262- }
263- }
264-
265- // ---
266- int nb_vocab = (int )tok.vocab .size ();
267- dlib::serialize (nb_vocab, out);
268- for (const auto & v : tok.vocab )
269- {
270- std::string token_str = tok.bytes_to_string (v.second );
271- dlib::serialize (token_str, out);
272- dlib::serialize (v.first , out);
273- }
// The version tag goes first; the five members follow in a fixed order that
// deserialize() reads back in the same sequence.
248+ serialize (" bpe_tokenizer2_" , out);
249+ serialize (tok.special_tokens , out);
250+ serialize (tok.special_token_map , out);
251+ serialize (tok.merges , out);
252+ serialize (tok.vocab , out);
253+ serialize (tok.vocab_size , out);
274254 }
275255
276256 // Load the tokenizer model and vocabulary from file
// Reads a tokenizer from `in` into `tok`. In this diff view, lines marked `-`
// are the previous reader and lines marked `+` are its replacement.
// Throws dlib::serialization_error when the leading version string is not the
// expected tag. The new check accepts only the "bpe_tokenizer2_" tag, so a
// stream starting with the previous "bpe_tokenizer_" tag is rejected.
277257 friend void deserialize (bpe_tokenizer& tok, std::istream& in) {
278258 std::string version;
279259 dlib::deserialize (version, in);
// NOTE(review): the spacing inside the string literals in this view looks like
// an extraction artifact; confirm the tag compared here is exactly the one
// serialize() writes.
280- if (version != " bpe_tokenizer_ " )
260+ if (version != " bpe_tokenizer2_ " )
281261 throw dlib::serialization_error (" Unexpected version '" + version + " ' found while deserializing dlib::bpe_tokenizer_." );
282-
283- // ---
284- int idx = BASE_VOCAB_SIZE + (int )tok.special_tokens .size (), nb_merges, nb_vocab, a, b;
285- tok.merges .clear ();
286- dlib::deserialize (nb_merges, in);
287- for (int m = 0 ; m < nb_merges; m++)
288- {
289- dlib::deserialize (a, in);
290- dlib::deserialize (b, in);
291- tok.merges [{a, b}] = idx;
292- idx++;
293- }
294-
295- // ---
296- std::string token_str;
297- tok.vocab .clear ();
298- dlib::deserialize (nb_vocab, in);
299- for (int v = 0 ; v < nb_vocab; v++)
300- {
301- dlib::deserialize (token_str, in);
302- dlib::deserialize (idx, in);
303- tok.vocab [idx] = tok.string_to_bytes (token_str);
304- }
// Members are read in the same order serialize() writes them:
// special_tokens, special_token_map, merges, vocab, vocab_size.
262+ deserialize (tok.special_tokens , in);
263+ deserialize (tok.special_token_map , in);
264+ deserialize (tok.merges , in);
265+ deserialize (tok.vocab , in);
266+ deserialize (tok.vocab_size , in);
305267 }
306268
307269 // Get the ID of a special token
@@ -415,15 +377,6 @@ namespace dlib
415377 return new_ids;
416378 }
417379
418- // Decode/Encode a base64 string to/from a UTF-8 string
419- static std::string base64_decode (const std::string& base64_str)
420- {
421- dlib::base64 decoder;
422- std::istringstream sin (base64_str);
423- std::ostringstream sout;
424- decoder.decode (sin, sout);
425- return sout.str ();
426- }
427380 static std::string base64_encode (const std::string& input) {
428381 dlib::base64 encoder;
429382 std::istringstream sin (input);
@@ -439,15 +392,9 @@ namespace dlib
439392 return base64_encode (data);
440393 }
441394
442- // Convert a string representation of bytes back to bytes
443- static std::vector<uint8_t > string_to_bytes (const std::string& str)
444- {
445- std::string decoded = base64_decode (str);
446- return std::vector<uint8_t >(decoded.begin (), decoded.end ());
447- }
448395 };
449396
450397}
451398
452399
453- #endif // DLIB_BPE_TOKENIZER_H
400+ #endif // DLIB_BPE_TOKENIZER_H
0 commit comments