
Commit 12561a1

Some more cleanup

1 parent dfb6089 commit 12561a1

File tree

2 files changed: +41 -59 lines


dlib/tokenizer/bpe_tokenizer.h

Lines changed: 11 additions & 9 deletions
@@ -71,7 +71,7 @@ namespace dlib
         // Train the tokenizer on the given text
         void train(const std::string& text, int vocab_size, bool verbose = false)
         {
-            assert(vocab_size >= BASE_VOCAB_SIZE);
+            DLIB_CASSERT(vocab_size >= BASE_VOCAB_SIZE);
             this->vocab_size = vocab_size;
             int num_merges = vocab_size - BASE_VOCAB_SIZE;
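The switch from assert to DLIB_CASSERT makes the precondition survive release builds: assert compiles away when NDEBUG is defined, while DLIB_CASSERT stays active and reports the violation through dlib's error machinery (a thrown dlib::fatal_error). A hedged sketch of a caller tripping the check; the include path is an assumption inferred from the file location in this diff:

#include <iostream>
#include <dlib/tokenizer/bpe_tokenizer.h>  // assumed include path, per this diff

int main()
{
    dlib::bpe_tokenizer tok;
    try
    {
        // BASE_VOCAB_SIZE is 256, so 100 violates the checked precondition,
        // even in a release (NDEBUG) build.
        tok.train("example text", 100);
    }
    catch (const dlib::fatal_error& e)
    {
        std::cerr << e.what() << std::endl;
    }
    return 0;
}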

@@ -122,7 +122,7 @@ namespace dlib
         }

         // Encode the given text into subword tokens
-        std::vector<int> encode(const std::string& text)
+        std::vector<int> encode(const std::string& text) const
         {
             std::vector<int> result_ids;
             std::mutex result_mutex;
@@ -210,13 +210,13 @@ namespace dlib
         }

         // Decode a single token ID back into text
-        std::string decode(int id, bool display_special_tokens = true)
+        std::string decode(int id, bool display_special_tokens = true) const
         {
             return decode(std::vector<int>({ id }), display_special_tokens);
         }

         // Decode a sequence of token IDs back into text
-        std::string decode(const std::vector<int>& ids, bool display_special_tokens = true)
+        std::string decode(const std::vector<int>& ids, bool display_special_tokens = true) const
         {
             std::vector<uint8_t> bytes;
             int vocab_size = static_cast<int>(get_vocab_size());
@@ -275,7 +275,7 @@ namespace dlib
         }

         // Get the total vocabulary size
-        size_t get_vocab_size(void) const
+        size_t get_vocab_size() const
         {
             return (vocab.size() + special_tokens.size());
         }
@@ -300,7 +300,7 @@ namespace dlib
                 return hash1 ^ (hash2 << 1);
             }
         };
-        std::unordered_map<std::pair<int, int>, int, pair_hash> get_stats(const std::vector<int>& ids)
+        std::unordered_map<std::pair<int, int>, int, pair_hash> get_stats(const std::vector<int>& ids) const
         {
             std::unordered_map<std::pair<int, int>, int, pair_hash> global_stats;
             std::mutex global_stats_mutex;
@@ -332,7 +332,8 @@ namespace dlib
         }

         // Finds the most frequent pair of tokens in the given statistics map that does not exceed the maximum token length
-        std::pair<int, int> get_most_frequent_pair(const std::unordered_map<std::pair<int, int>, int, pair_hash>& stats) {
+        std::pair<int, int> get_most_frequent_pair(const std::unordered_map<std::pair<int, int>, int, pair_hash>& stats) const
+        {
             std::pair<int, int> best_pair = { -1, -1 }; // Initialize the best pair to an invalid value
             double max_score = 0; // Initialize the maximum score to 0

@@ -342,7 +343,7 @@ namespace dlib
342343
int count = stat.second; // Extract the frequency count
343344

344345
// Check if the new token formed by merging the pair would exceed the maximum allowed length
345-
size_t new_token_length = vocab[pair.first].size() + vocab[pair.second].size();
346+
size_t new_token_length = vocab.at(pair.first).size() + vocab.at(pair.second).size();
346347
if (new_token_length > MAX_TOKEN_LENGTH) continue; // Skip this pair if it exceeds the maximum token length
347348

348349
// Calculate the score for this pair (frequency * length_penalty)
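The vocab member is a std::map, and operator[] has no const overload (it may insert a default-constructed value for a missing key), so const-qualifying get_most_frequent_pair forces the switch to at(), which is const and throws std::out_of_range instead of inserting. A self-contained illustration of the distinction, independent of dlib:

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// vocab[id] would not compile here: std::map::operator[] is non-const.
// at() is const-qualified and throws std::out_of_range for a missing key.
size_t token_length(const std::map<int, std::vector<uint8_t>>& vocab, int id)
{
    return vocab.at(id).size();
}

int main()
{
    std::map<int, std::vector<uint8_t>> vocab = {
        { 65,  { 'A' } },
        { 256, { 't', 'h', 'e' } }
    };
    std::cout << token_length(vocab, 256) << "\n";  // prints 3
}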
@@ -360,7 +361,8 @@ namespace dlib
         }

         // Merge the most frequent pair in the token sequence
-        std::vector<int> merge(std::vector<int>& ids, const std::pair<int, int>& pair, int idx) {
+        std::vector<int> merge(std::vector<int>& ids, const std::pair<int, int>& pair, int idx) const
+        {
             std::vector<int> new_ids;
             new_ids.reserve(ids.size()); // Reserve space to avoid reallocations
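Taken together, the changes in this file const-qualify every read-only member (encode, both decode overloads, get_stats, get_most_frequent_pair, merge), so a trained tokenizer can now be used through a const reference or pointer. A minimal usage sketch; the training string is made up for illustration and the include path is assumed from this diff:

#include <iostream>
#include <vector>
#include <dlib/tokenizer/bpe_tokenizer.h>  // assumed include path, per this diff

int main()
{
    dlib::bpe_tokenizer tok;
    tok.train("the quick brown fox jumps over the lazy dog", 300);

    // encode() and decode() are now const, so a const reference suffices.
    const dlib::bpe_tokenizer& ctok = tok;
    std::vector<int> ids = ctok.encode("the quick brown fox");
    std::cout << ctok.decode(ids) << "\n";
    return 0;
}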

dlib/tokenizer/bpe_tokenizer_abstract.h

Lines changed: 30 additions & 50 deletions
@@ -19,9 +19,7 @@ namespace dlib
     class bpe_tokenizer
     {
         /*!
-            CLASS bpe_tokenizer
-                A Byte Pair Encoding (BPE) tokenizer for text processing.
-
+            WHAT THIS OBJECT REPRESENTS
             This class implements a Byte Pair Encoding (BPE) tokenizer, which is a subword
             tokenization algorithm commonly used in natural language processing (NLP). The
             BPE algorithm iteratively merges the most frequent pairs of bytes or characters
@@ -37,21 +35,17 @@ namespace dlib
             text into subword tokens, and decoding tokens back into text. The tokenizer can be
             serialized and deserialized to/from a file, allowing for easy storage and reuse.

-            INITIAL VALUE
-                - The base vocabulary is initialized with single-byte tokens (0-255).
-                - Special tokens are pre-defined and assigned IDs starting from 256.
-                - The maximum token length is set to 8 bytes.
-
-            WHAT THIS OBJECT REPRESENTS
-                This object represents a BPE tokenizer capable of encoding and decoding text
-                using a learned subword vocabulary. It is designed to handle UTF-8 encoded text
-                and supports multi-threaded processing for efficient tokenization.
-
             REFERENCES
                 - Sennrich, R., Haddow, B., & Birch, A. (2016). Neural Machine Translation of
                   Rare Words with Subword Units. In Proceedings of the 54th Annual Meeting of
                   the Association for Computational Linguistics (ACL 2016).
+
+            INITIAL VALUE
+                - The base vocabulary is initialized with single-byte tokens (0-255).
+                - Special tokens are pre-defined and assigned IDs starting from 256.
+                - The maximum token length is set to 8 bytes.
         !*/
+
     public:
         bpe_tokenizer();
         /*!
@@ -77,7 +71,7 @@ namespace dlib

         std::vector<int> encode(
             const std::string& text
-        );
+        ) const;
         /*!
             ensures
                 - Encodes the input text into a sequence of subword tokens.
@@ -88,32 +82,19 @@ namespace dlib
         std::string decode(
             const std::vector<int>& ids,
             bool display_special_tokens = true
-        );
+        ) const;
         /*!
             ensures
                 - Decodes a sequence of token IDs back into a human-readable string.
                 - If `display_special_tokens` is true, special tokens are included in the output.
                 - Returns the decoded text as a UTF-8 encoded string.
         !*/

-        void serialize(
-            const bpe_tokenizer& tok,
-            std::ostream& out
-        );
-        /*!
-            ensures
-                - Serializes the tokenizer's vocabulary and merge operations to the output stream.
-                - The serialized data can be used to reconstruct the tokenizer later.
-        !*/
-
-        void deserialize(
-            bpe_tokenizer& tok,
-            std::istream& in
-        );
+        std::string decode(int id, bool display_special_tokens = true) const
+        { return decode(std::vector<int>({ id }), display_special_tokens); }
         /*!
             ensures
-                - Deserializes the tokenizer's vocabulary and merge operations from the input stream.
-                - Restores the tokenizer to the state it was in when serialized.
+                - Decodes a single token ID back into text.
         !*/

         int get_special_token_id(
@@ -130,26 +111,25 @@ namespace dlib
         ensures
             - Returns the total size of the vocabulary, including base tokens and special tokens.
         !*/
-
-    private:
-        // Private implementation details
-        std::map<std::string, int> special_tokens;
-        std::unordered_map<int, std::string> special_token_map;
-        std::map<std::pair<int, int>, int> merges;
-        std::map<int, std::vector<uint8_t>> vocab;
-        int vocab_size;
-
-        static const size_t MAX_TOKEN_LENGTH = 8;
-        static const int BASE_VOCAB_SIZE = 256;
-
-        // Helper functions
-        std::unordered_map<std::pair<int, int>, int, pair_hash> get_stats(const std::vector<int>& ids);
-        std::pair<int, int> get_most_frequent_pair(const std::unordered_map<std::pair<int, int>, int, pair_hash>& stats);
-        std::vector<int> merge(std::vector<int>& ids, const std::pair<int, int>& pair, int idx);
-        std::string bytes_to_string(const std::vector<uint8_t>& bytes);
-        std::vector<uint8_t> string_to_bytes(const std::string& str);
     };

+    void serialize(
+        const bpe_tokenizer& tok,
+        std::ostream& out
+    );
+    /*!
+        ensures
+            - Saves the entire state of tok to out.
+    !*/
+
+    void deserialize(
+        bpe_tokenizer& tok,
+        std::istream& in
+    );
+    /*!
+        ensures
+            - Restores the state of a bpe_tokenizer from a serialized state.
+    !*/
 }

-#endif // DLIB_BPE_TOKENIZER_ABSTRACT_
+#endif // DLIB_BPE_TOKENIZER_ABSTRACT_
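With serialize and deserialize now declared as free functions at namespace scope (dlib's usual serialization convention) rather than as members, a trained tokenizer can be round-tripped through a stream. A minimal sketch; the filename and training string are made up for illustration, and the include path is assumed from this diff:

#include <fstream>
#include <dlib/tokenizer/bpe_tokenizer.h>  // assumed include path, per this diff

int main()
{
    dlib::bpe_tokenizer tok;
    tok.train("illustrative training corpus", 300);

    // Save the entire state of tok to a file (made-up filename).
    std::ofstream fout("tokenizer.dat", std::ios::binary);
    dlib::serialize(tok, fout);
    fout.close();

    // Restore the saved state into a fresh tokenizer.
    dlib::bpe_tokenizer tok2;
    std::ifstream fin("tokenizer.dat", std::ios::binary);
    dlib::deserialize(tok2, fin);
    return 0;
}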
