Skip to content

Commit dfb6089

Browse files
committed
cleanup serialization code and add missing fields
1 parent 1cd0634 commit dfb6089

File tree

2 files changed

+19
-72
lines changed

2 files changed

+19
-72
lines changed

dlib/test/tokenizer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ namespace
384384
adopted in many state-of-the-art NLP models, including GPT and BERT.
385385
)";
386386

387-
test.train(training_text, 300, true);
387+
test.train(training_text, 300);
388388

389389
std::ostringstream out_stream;
390390
serialize(test, out_stream);
@@ -429,4 +429,4 @@ namespace
429429
}
430430
} a;
431431

432-
}
432+
}

dlib/tokenizer/bpe_tokenizer.h

Lines changed: 17 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
// Copyright (C) 2025 Davis E. King (davis@dlib.net)
2+
// License: Boost Software License See LICENSE.txt for the full license.
13
#ifndef DLIB_BPE_TOKENIZER_H
24
#define DLIB_BPE_TOKENIZER_H
35

@@ -14,6 +16,7 @@
1416

1517
#include "../base64.h"
1618
#include "../serialize.h"
19+
#include "bpe_tokenizer_abstract.h"
1720

1821
namespace dlib
1922
{
@@ -113,10 +116,9 @@ namespace dlib
113116
<< std::to_string(pair.first) << "," << std::to_string(pair.second) << ") -> " << std::to_string(idx)
114117
<< " (" << bytes_to_string(vocab[idx]) << ") had "
115118
<< std::to_string(stats[pair]) << " occurrences"
116-
<< std::flush;
119+
<< std::endl;
117120
}
118121
}
119-
std::cout << "\ntraining done\n";
120122
}
121123

122124
// Encode the given text into subword tokens
@@ -243,65 +245,25 @@ namespace dlib
243245
// Save the tokenizer model and vocabulary to file
244246
friend void serialize(const bpe_tokenizer& tok, std::ostream& out)
245247
{
246-
dlib::serialize("bpe_tokenizer_", out);
247-
248-
//---
249-
int nb_merges = tok.merges.size();
250-
dlib::serialize(nb_merges, out);
251-
for (int idx = (BASE_VOCAB_SIZE + (int)tok.special_tokens.size());
252-
idx < (tok.vocab_size + (int)tok.special_tokens.size()); ++idx)
253-
{
254-
for (const auto& merge_pair : tok.merges)
255-
{
256-
if (merge_pair.second == idx)
257-
{
258-
dlib::serialize(merge_pair.first.first, out);
259-
dlib::serialize(merge_pair.first.second, out);
260-
break;
261-
}
262-
}
263-
}
264-
265-
//---
266-
int nb_vocab = (int)tok.vocab.size();
267-
dlib::serialize(nb_vocab, out);
268-
for (const auto& v : tok.vocab)
269-
{
270-
std::string token_str = tok.bytes_to_string(v.second);
271-
dlib::serialize(token_str, out);
272-
dlib::serialize(v.first, out);
273-
}
248+
serialize("bpe_tokenizer2_", out);
249+
serialize(tok.special_tokens, out);
250+
serialize(tok.special_token_map, out);
251+
serialize(tok.merges, out);
252+
serialize(tok.vocab, out);
253+
serialize(tok.vocab_size, out);
274254
}
275255

276256
// Load the tokenizer model and vocabulary from file
277257
friend void deserialize(bpe_tokenizer& tok, std::istream& in) {
278258
std::string version;
279259
dlib::deserialize(version, in);
280-
if (version != "bpe_tokenizer_")
260+
if (version != "bpe_tokenizer2_")
281261
throw dlib::serialization_error("Unexpected version '" + version + "' found while deserializing dlib::bpe_tokenizer_.");
282-
283-
//---
284-
int idx = BASE_VOCAB_SIZE + (int)tok.special_tokens.size(), nb_merges, nb_vocab, a, b;
285-
tok.merges.clear();
286-
dlib::deserialize(nb_merges, in);
287-
for (int m = 0; m < nb_merges; m++)
288-
{
289-
dlib::deserialize(a, in);
290-
dlib::deserialize(b, in);
291-
tok.merges[{a, b}] = idx;
292-
idx++;
293-
}
294-
295-
//---
296-
std::string token_str;
297-
tok.vocab.clear();
298-
dlib::deserialize(nb_vocab, in);
299-
for (int v = 0; v < nb_vocab; v++)
300-
{
301-
dlib::deserialize(token_str, in);
302-
dlib::deserialize(idx, in);
303-
tok.vocab[idx] = tok.string_to_bytes(token_str);
304-
}
262+
deserialize(tok.special_tokens, in);
263+
deserialize(tok.special_token_map, in);
264+
deserialize(tok.merges, in);
265+
deserialize(tok.vocab, in);
266+
deserialize(tok.vocab_size, in);
305267
}
306268

307269
// Get the ID of a special token
@@ -415,15 +377,6 @@ namespace dlib
415377
return new_ids;
416378
}
417379

418-
// Decode/Encode a base64 string to/from a UTF-8 string
419-
static std::string base64_decode(const std::string& base64_str)
420-
{
421-
dlib::base64 decoder;
422-
std::istringstream sin(base64_str);
423-
std::ostringstream sout;
424-
decoder.decode(sin, sout);
425-
return sout.str();
426-
}
427380
static std::string base64_encode(const std::string& input) {
428381
dlib::base64 encoder;
429382
std::istringstream sin(input);
@@ -439,15 +392,9 @@ namespace dlib
439392
return base64_encode(data);
440393
}
441394

442-
// Convert a string representation of bytes back to bytes
443-
static std::vector<uint8_t> string_to_bytes(const std::string& str)
444-
{
445-
std::string decoded = base64_decode(str);
446-
return std::vector<uint8_t>(decoded.begin(), decoded.end());
447-
}
448395
};
449396

450397
}
451398

452399

453-
#endif // DLIB_BPE_TOKENIZER_H
400+
#endif // DLIB_BPE_TOKENIZER_H

0 commit comments

Comments (0)