Skip to content

Commit 66343ab

Browse files
committed
llama : vocab pimpl
ggml-ci
1 parent 9d6f9df commit 66343ab

File tree

2 files changed

+62
-58
lines changed

2 files changed

+62
-58
lines changed

src/llama-vocab.cpp

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,18 @@ struct naive_trie {
6666
//
6767

6868
struct llm_tokenizer {
69-
llm_tokenizer() {}
70-
virtual ~llm_tokenizer() = default;
69+
llm_tokenizer() {}
70+
virtual ~llm_tokenizer() = default;
7171
};
7272

73+
struct llama_vocab::impl {
74+
std::unique_ptr<llm_tokenizer> tokenizer;
75+
};
76+
77+
llama_vocab::llama_vocab() : pimpl(new impl()) {
78+
}
79+
7380
llama_vocab::~llama_vocab() {
74-
delete tokenizer;
7581
}
7682

7783
int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
@@ -194,7 +200,7 @@ struct llm_bigram_spm {
194200
};
195201

196202
struct llm_tokenizer_spm : llm_tokenizer {
197-
llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
203+
llm_tokenizer_spm(const llama_vocab & /*vocab*/) {}
198204
};
199205

200206
struct llm_tokenizer_spm_session {
@@ -364,7 +370,7 @@ struct llm_bigram_bpe {
364370
};
365371

366372
struct llm_tokenizer_bpe : llm_tokenizer {
367-
llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() {
373+
llm_tokenizer_bpe(const llama_vocab & vocab) {
368374
GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
369375
switch (vocab.type_pre) {
370376
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
@@ -499,8 +505,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
499505
};
500506

501507
struct llm_tokenizer_bpe_session {
502-
llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab),
503-
bpe_tokenizer(static_cast<const llm_tokenizer_bpe *>(vocab.tokenizer)) {}
508+
llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
504509

505510
static void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) {
506511
output.push_back(token_id);
@@ -541,7 +546,7 @@ struct llm_tokenizer_bpe_session {
541546

542547
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
543548
int final_prev_index = -1;
544-
const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs);
549+
const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs);
545550

546551
symbols_final.clear();
547552

@@ -671,7 +676,7 @@ struct llm_tokenizer_bpe_session {
671676
}
672677

673678
const llama_vocab & vocab;
674-
const llm_tokenizer_bpe * bpe_tokenizer;
679+
const llm_tokenizer_bpe & tokenizer;
675680

676681
std::vector<llm_symbol> symbols;
677682
std::vector<llm_symbol> symbols_final;
@@ -683,7 +688,7 @@ struct llm_tokenizer_bpe_session {
683688
//
684689

685690
struct llm_tokenizer_wpm : llm_tokenizer {
686-
llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
691+
llm_tokenizer_wpm(const llama_vocab & /*vocab*/) {}
687692
};
688693

689694
struct llm_tokenizer_wpm_session {
@@ -800,7 +805,7 @@ struct llm_tokenizer_wpm_session {
800805
//
801806

802807
struct llm_tokenizer_ugm : llm_tokenizer {
803-
llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() {
808+
llm_tokenizer_ugm(const llama_vocab & vocab) {
804809
if (vocab.precompiled_charsmap.size() > 0) {
805810
size_t charsmap_offset = 0;
806811

@@ -867,8 +872,7 @@ struct llm_tokenizer_ugm : llm_tokenizer {
867872
};
868873

869874
struct llm_tokenizer_ugm_session {
870-
llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab),
871-
ugm_tokenizer(static_cast<const llm_tokenizer_ugm *>(vocab.tokenizer)) {}
875+
llm_tokenizer_ugm_session(const llama_vocab & vocab, const llm_tokenizer_ugm & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
872876

873877
/* This implementation is based on SentencePiece optimized Viterbi algorithm for
874878
* unigram language models. The general idea is to:
@@ -908,7 +912,7 @@ struct llm_tokenizer_ugm_session {
908912
// traverse the token matcher trie to find a matching token
909913
bool single_codepoint_token_found = false;
910914
const struct best_tokenization & current_best = tokenization_results[input_offset];
911-
const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]);
915+
const struct naive_trie * node = tokenizer.token_matcher.traverse(normalized[prefix_offset++]);
912916

913917
while (prefix_offset <= input_len && node != NULL) {
914918
// check if we found valid token in prefix
@@ -938,7 +942,7 @@ struct llm_tokenizer_ugm_session {
938942
// if we didn't find a valid token corresponding to the whole UTF code point
939943
// then use unknown token as the tokenization of this UTF code point
940944
if (!single_codepoint_token_found) {
941-
const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score;
945+
const double challenger_score = current_best.score_sum + tokenizer.unknown_token_score;
942946
prefix_offset = input_offset + n_utf8_code_units;
943947
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
944948
if (challenger_score > current_champ.score_sum) {
@@ -982,7 +986,7 @@ struct llm_tokenizer_ugm_session {
982986
normalized->clear();
983987
normalized->reserve(input.size() * 3);
984988

985-
const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " ";
989+
const std::string space = vocab.tokenizer_escape_whitespaces ? tokenizer.escaped_space : " ";
986990

987991
bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
988992
bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
@@ -1078,16 +1082,16 @@ struct llm_tokenizer_ugm_session {
10781082

10791083
// if input prefix matches some user-defined token return this token as normalization result
10801084
auto user_defined_token_match =
1081-
ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
1085+
tokenizer.user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
10821086
if (user_defined_token_match.second > 0) {
10831087
return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
10841088
}
10851089

10861090
size_t longest_prefix_length = 0;
10871091
size_t longest_prefix_offset = 0;
10881092

1089-
if (ugm_tokenizer->xcda_array_size > 0) {
1090-
struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size);
1093+
if (tokenizer.xcda_array_size > 0) {
1094+
struct xcda_array_view xcda_view(tokenizer.xcda_array, tokenizer.xcda_array_size);
10911095

10921096
// Find the longest normalized sequence matching the input prefix by walking
10931097
// the XOR-compressed compact double array (XCDA) starting from the root node
@@ -1123,10 +1127,10 @@ struct llm_tokenizer_ugm_session {
11231127

11241128
if (longest_prefix_length > 0) {
11251129
// we have a match, so return the replacement sequence
1126-
if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) {
1130+
if (longest_prefix_offset >= tokenizer.prefix_replacements_size) {
11271131
throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
11281132
}
1129-
const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset];
1133+
const char * prefix_replacement = &(tokenizer.prefix_replacements)[longest_prefix_offset];
11301134
return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
11311135
}
11321136

@@ -1143,7 +1147,7 @@ struct llm_tokenizer_ugm_session {
11431147
}
11441148

11451149
const llama_vocab & vocab;
1146-
const llm_tokenizer_ugm * ugm_tokenizer;
1150+
const llm_tokenizer_ugm & tokenizer;
11471151
};
11481152

11491153
//
@@ -1205,7 +1209,7 @@ static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escape
12051209
}
12061210

12071211
struct llm_tokenizer_rwkv : llm_tokenizer {
1208-
llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() {
1212+
llm_tokenizer_rwkv(const llama_vocab & vocab) {
12091213
// RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
12101214
// For now, we decode the vocab here into the lookup we'll use for tokenization.
12111215

@@ -1221,13 +1225,12 @@ struct llm_tokenizer_rwkv : llm_tokenizer {
12211225
};
12221226

12231227
struct llm_tokenizer_rwkv_session {
1224-
llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab),
1225-
rwkv_tokenizer(static_cast<const llm_tokenizer_rwkv &>(*vocab.tokenizer)) {}
1228+
llm_tokenizer_rwkv_session(const llama_vocab & vocab, const llm_tokenizer_rwkv & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
12261229

12271230
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
12281231
uint32_t position = 0;
12291232
while (position < text.size()) {
1230-
const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]);
1233+
const struct naive_trie * node = tokenizer.token_matcher.traverse(text[position]);
12311234
if (node == NULL) {
12321235
// no matching token found, add unknown token
12331236
output.push_back(vocab.special_unk_id);
@@ -1254,25 +1257,25 @@ struct llm_tokenizer_rwkv_session {
12541257

12551258
private:
12561259
const llama_vocab & vocab;
1257-
const llm_tokenizer_rwkv & rwkv_tokenizer;
1260+
const llm_tokenizer_rwkv & tokenizer;
12581261
};
12591262

12601263
void llama_vocab::init_tokenizer() {
12611264
switch (type) {
12621265
case LLAMA_VOCAB_TYPE_SPM:
1263-
tokenizer = new llm_tokenizer_spm(*this);
1266+
pimpl->tokenizer = std::make_unique<llm_tokenizer_spm>(*this);
12641267
break;
12651268
case LLAMA_VOCAB_TYPE_BPE:
1266-
tokenizer = new llm_tokenizer_bpe(*this);
1269+
pimpl->tokenizer = std::make_unique<llm_tokenizer_bpe>(*this);
12671270
break;
12681271
case LLAMA_VOCAB_TYPE_WPM:
1269-
tokenizer = new llm_tokenizer_wpm(*this);
1272+
pimpl->tokenizer = std::make_unique<llm_tokenizer_wpm>(*this);
12701273
break;
12711274
case LLAMA_VOCAB_TYPE_UGM:
1272-
tokenizer = new llm_tokenizer_ugm(*this);
1275+
pimpl->tokenizer = std::make_unique<llm_tokenizer_ugm>(*this);
12731276
break;
12741277
case LLAMA_VOCAB_TYPE_RWKV:
1275-
tokenizer = new llm_tokenizer_rwkv(*this);
1278+
pimpl->tokenizer = std::make_unique<llm_tokenizer_rwkv>(*this);
12761279
break;
12771280
default:
12781281
GGML_ABORT("unsupported vocab type");
@@ -1566,7 +1569,7 @@ std::vector<llama_vocab::id> llama_vocab::tokenize(
15661569
std::string raw_text,
15671570
bool add_special,
15681571
bool parse_special) const {
1569-
GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
1572+
GGML_ASSERT(pimpl->tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
15701573

15711574
std::vector<id> output;
15721575
std::forward_list<fragment_buffer_variant> fragment_buffer;
@@ -1628,7 +1631,7 @@ std::vector<llama_vocab::id> llama_vocab::tokenize(
16281631
} break;
16291632
case LLAMA_VOCAB_TYPE_BPE:
16301633
{
1631-
llm_tokenizer_bpe_session session(*this);
1634+
llm_tokenizer_bpe_session session(*this, *static_cast<const llm_tokenizer_bpe *>(pimpl->tokenizer.get()));
16321635
// it calls some other methods that are not exist in llm_tokenizer,
16331636
// here just cast it to bpe tokenizer object
16341637
if (add_special) {
@@ -1685,7 +1688,7 @@ std::vector<llama_vocab::id> llama_vocab::tokenize(
16851688
GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
16861689
output.push_back(special_bos_id);
16871690
}
1688-
llm_tokenizer_ugm_session session(*this);
1691+
llm_tokenizer_ugm_session session(*this, *static_cast<const llm_tokenizer_ugm *>(pimpl->tokenizer.get()));
16891692

16901693
for (const auto & fragment : fragment_buffer) {
16911694
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1713,7 +1716,7 @@ std::vector<llama_vocab::id> llama_vocab::tokenize(
17131716
} break;
17141717
case LLAMA_VOCAB_TYPE_RWKV:
17151718
{
1716-
llm_tokenizer_rwkv_session session(*this);
1719+
llm_tokenizer_rwkv_session session(*this, *static_cast<const llm_tokenizer_rwkv *>(pimpl->tokenizer.get()));
17171720
for (const auto & fragment : fragment_buffer) {
17181721
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
17191722
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -1872,7 +1875,7 @@ int32_t llama_vocab::detokenize(
18721875
return 0;
18731876
}
18741877

1875-
GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
1878+
GGML_ASSERT(pimpl->tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
18761879

18771880
int32_t avail = text_len_max;
18781881
int32_t total = 0;

src/llama-vocab.h

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
#include <unordered_map>
88
#include <map>
99
#include <set>
10-
11-
struct llm_tokenizer;
10+
#include <memory>
1211

1312
struct llama_vocab {
1413
using id = llama_token;
@@ -73,9 +72,7 @@ struct llama_vocab {
7372

7473
std::vector<char> precompiled_charsmap;
7574

76-
llm_tokenizer * tokenizer = nullptr;
77-
78-
llama_vocab() = default;
75+
llama_vocab();
7976
~llama_vocab();
8077

8178
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
@@ -131,30 +128,30 @@ struct llama_vocab {
131128
bool add_eos_token() const;
132129

133130
std::vector<id> tokenize(
134-
std::string raw_text,
135-
bool add_special,
136-
bool parse_special = false) const;
131+
std::string raw_text,
132+
bool add_special,
133+
bool parse_special = false) const;
137134

138135
int32_t tokenize(
139-
const char * text,
140-
int32_t text_len,
141-
llama_token * tokens,
142-
int32_t n_tokens_max,
143-
bool add_special,
144-
bool parse_special) const;
136+
const char * text,
137+
int32_t text_len,
138+
llama_token * tokens,
139+
int32_t n_tokens_max,
140+
bool add_special,
141+
bool parse_special) const;
145142

146143
// does not write null-terminator to buf
147144
int32_t token_to_piece(
148-
llama_token token,
149-
char * buf,
150-
int32_t length,
151-
int32_t lstrip,
152-
bool special) const;
145+
llama_token token,
146+
char * buf,
147+
int32_t length,
148+
int32_t lstrip,
149+
bool special) const;
153150

154151
// check if token0 is contained as a prefix in token1
155152
bool token_is_prefix(
156-
llama_token token0,
157-
llama_token token1) const;
153+
llama_token token0,
154+
llama_token token1) const;
158155

159156
int32_t detokenize(
160157
const llama_token * tokens,
@@ -167,4 +164,8 @@ struct llama_vocab {
167164
std::string detokenize(
168165
const std::vector<llama_token> & tokens,
169166
bool special) const;
167+
168+
private:
169+
struct impl;
170+
std::unique_ptr<impl> pimpl;
170171
};

0 commit comments

Comments
 (0)