Commit 304c815

DRY: Using vocab instead of model
1 parent 4dc9fc3 · commit 304c815


7 files changed: +50 −58 lines

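To illustrate the entry point touched by this commit, here is a minimal caller-side sketch. It is an illustrative assumption, not part of the diff: the model handle, sampler parameters, and breaker strings below are placeholders, and only the signature of llama_sampler_init_dry matches the code changed here.

    // Hypothetical caller (model is assumed to be an already-loaded llama_model).
    // The public API still takes the model; internally it forwards model->vocab
    // to llama_sampler_init_dry_impl, as introduced by this commit.
    const char * seq_breakers[] = { "\n", ":" };
    struct llama_sampler * smpl = llama_sampler_init_dry(
        model,
        /* context_size       = */ 4096,
        /* dry_multiplier     = */ 0.8f,
        /* dry_base           = */ 1.75f,
        /* dry_allowed_length = */ 2,
        /* dry_penalty_last_n = */ -1,
        seq_breakers,
        sizeof(seq_breakers) / sizeof(seq_breakers[0]));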

src/llama-impl.h

Lines changed: 0 additions & 26 deletions
@@ -70,32 +70,6 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
         struct llama_context * ctx
 );
 
-// exposing wrapper function that takes "model" instead of "vocab", to be used internally
-std::vector<llama_token> llama_tokenize_internal(
-        const struct llama_model * model,
-        const std::string & raw_text,
-        bool add_special = false,
-        bool parse_special = true);
-
-static std::string llama_detokenize(const struct llama_model * model, const std::vector<llama_token> & tokens, bool special) {
-    if (model == nullptr) { // model is passed as nullptr in test-sampling.cpp
-        return "";
-    }
-    std::string text;
-    text.resize(std::max(text.capacity(), tokens.size()));
-    int32_t n_chars = llama_detokenize(model, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
-    if (n_chars < 0) {
-        text.resize(-n_chars);
-        n_chars = llama_detokenize(model, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
-        GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
-    }
-
-    text.resize(n_chars);
-
-    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
-    return text;
-}
-
 // the ring buffer works similarly to std::deque, but with a fixed capacity
 template<typename T>
 struct ring_buffer {

src/llama-sampling.cpp

Lines changed: 7 additions & 12 deletions
@@ -1663,8 +1663,6 @@ struct llama_sampler * llama_sampler_init_penalties(
 // DRY
 
 struct llama_sampler_dry {
-    const llama_model * model;
-
     int32_t total_context_size;
 
     const float dry_multiplier;
@@ -1679,10 +1677,9 @@ struct llama_sampler_dry {
 };
 
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
-static void get_overlapping_token_sequences(const struct llama_model * model, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
-    const int n_vocab = llama_n_vocab(model);
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        std::string word = llama_detokenize(model, {token_id}, true);
+static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
+    for (llama_token token_id = 0; token_id < (int)vocab.n_vocab; token_id++) {
+        std::string word = llama_detokenize(vocab, {token_id}, true);
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
         } else {
@@ -1698,7 +1695,7 @@ static void get_overlapping_token_sequences(const struct llama_model * model, co
             }
         }
         if (match) {
-            std::vector<llama_token> tokenization = llama_tokenize_internal(model, str.substr(i), false, false);
+            std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
             if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                 tokenization.resize(max_tail_len);
             }
@@ -1951,8 +1948,7 @@ static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (llama_sampler_dry *) smpl->ctx;
 
-    auto * result = llama_sampler_init_dry(ctx->model, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
-
+    auto * result = llama_sampler_init_dry(nullptr, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
     // Copy the state, including the processed breakers
     {
         auto * result_ctx = (llama_sampler_dry *) result->ctx;
@@ -1978,7 +1974,7 @@ static struct llama_sampler_i llama_sampler_dry_i = {
     /* .free = */ llama_sampler_dry_free,
 };
 
-struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
     if (dry_multiplier < 0 || dry_base <= 0 || dry_allowed_length < 0) {
         return nullptr;
     }
@@ -2008,14 +2004,13 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model,
             sequence_break.resize(MAX_CHAR_LEN);
         }
 
-        get_overlapping_token_sequences(model, sequence_break, processed_breakers, MAX_SEQ_LEN);
+        get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
         }
     }
 
     return new llama_sampler {
         /* .iface = */ &llama_sampler_dry_i,
         /* .ctx   = */ new llama_sampler_dry {
-            /* .model = */ model,
             /* .total_context_size = */ context_size,
             /* .dry_multiplier = */ dry_multiplier,
             /* .dry_base = */ dry_base,
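For readers unfamiliar with the DRY sampler, the function refactored above, get_overlapping_token_sequences, scans every vocabulary token and records those whose text contains a sequence-breaker string, or whose text ends with a prefix of one, together with the tail of the breaker still left to match. The following self-contained sketch approximates that scan under simplifying assumptions: a plain string vector stands in for llama_vocab, and the tail is kept as a string instead of being re-tokenized, so it is illustrative only, not a drop-in replacement.

    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
        // toy stand-in for the vocabulary: token id -> token text
        const std::vector<std::string> toy_vocab = { "hello", "foo\n", "\n\n", "bar" };
        const std::string breaker = "\n\n";

        // token id -> remaining tail of the breaker that still has to follow
        std::unordered_multimap<int, std::string> sequences;

        for (int id = 0; id < (int) toy_vocab.size(); ++id) {
            const std::string & word = toy_vocab[id];
            if (word.find(breaker) != std::string::npos) {
                sequences.emplace(id, "");   // token text already contains the whole breaker
                continue;
            }
            // does some suffix of the token text match a prefix of the breaker?
            for (size_t i = 1; i < breaker.size(); ++i) {
                if (word.size() >= i && word.compare(word.size() - i, i, breaker, 0, i) == 0) {
                    sequences.emplace(id, breaker.substr(i));   // record the unmatched tail
                    break;
                }
            }
        }

        std::printf("%zu overlapping tokens found\n", sequences.size());
        return 0;
    }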

src/llama-sampling.h

Lines changed: 17 additions & 8 deletions
@@ -3,7 +3,6 @@
 // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
 
 #include "llama-grammar.h"
-#include "llama-impl.h"
 
 struct llama_vocab;
 struct llama_grammar;
@@ -30,11 +29,21 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
 struct llama_sampler * llama_sampler_init_infill_impl(
         const struct llama_vocab & vocab);
 
+struct llama_sampler * llama_sampler_init_dry_impl(
+        const struct llama_vocab & vocab,
+        int32_t context_size,
+        float dry_multiplier,
+        float dry_base,
+        int32_t dry_allowed_length,
+        int32_t dry_penalty_last_n,
+        const char ** seq_breakers,
+        size_t num_breakers);
+
 struct llama_sampler * llama_sampler_init_dry(
-    const struct llama_model * model,
-    int32_t context_size,
-    float dry_multiplier,
-    float dry_base,
-    int32_t dry_allowed_length,
-    int32_t dry_penalty_last_n,
-    const std::vector<std::vector<llama_token>>& seq_breakers);
+        const struct llama_model * model,
+        int32_t context_size,
+        float dry_multiplier,
+        float dry_base,
+        int32_t dry_allowed_length,
+        int32_t dry_penalty_last_n,
+        const std::vector<std::vector<llama_token>>& seq_breakers);

src/llama-vocab.cpp

Lines changed: 16 additions & 0 deletions
@@ -1966,3 +1966,19 @@ int32_t llama_detokenize_impl(
 
     return total <= text_len_max ? total : -total;
 }
+
+std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector<llama_token> & tokens, bool special) {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
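The helper added here relies on the return convention of llama_detokenize_impl visible in the diff: a negative return value means the output buffer was too small and its absolute value is the required size, so the wrapper resizes the string and retries once. A hedged usage sketch follows; the vocab reference and token ids are placeholders, not values from this commit.

    // Assuming `vocab` is a const llama_vocab & obtained from a loaded model
    // (e.g. model->vocab inside llama.cpp internals):
    std::vector<llama_token> tokens = { /* token ids produced by tokenization */ };
    std::string text = llama_detokenize(vocab, tokens, /* special = */ true);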

src/llama-vocab.h

Lines changed: 5 additions & 0 deletions
@@ -163,3 +163,8 @@ int32_t llama_detokenize_impl(
         int32_t text_len_max,
         bool remove_special,
         bool unparse_special);
+
+std::string llama_detokenize(
+        const struct llama_vocab & vocab,
+        const std::vector<llama_token> & tokens,
+        bool special);

src/llama.cpp

Lines changed: 4 additions & 9 deletions
@@ -21452,15 +21452,6 @@ int32_t llama_tokenize(
     return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
 }
 
-// wrapper function that takes "model" instead of "vocab", to be used internally
-std::vector<llama_token> llama_tokenize_internal(
-        const struct llama_model * model,
-        const std::string & raw_text,
-        bool add_special,
-        bool parse_special) {
-    return llama_tokenize_internal(model->vocab, raw_text, add_special, parse_special);
-}
-
 int32_t llama_token_to_piece(
         const struct llama_model * model,
         llama_token token,
@@ -21805,6 +21796,10 @@ struct llama_sampler * llama_sampler_init_infill(const struct llama_model * mode
     return llama_sampler_init_infill_impl(model->vocab);
 }
 
+struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+    return llama_sampler_init_dry_impl(model->vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
+}
+
 //
 // model split
 //

tests/test-sampling.cpp

Lines changed: 1 addition & 3 deletions
@@ -208,10 +208,8 @@ static void test_dry(
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    const int32_t context_size = 1024;
-    struct llama_model * model = nullptr;
 
-    struct llama_sampler * sampler = llama_sampler_init_dry(model, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers);
+    struct llama_sampler * sampler = llama_sampler_init_dry(nullptr, 1024, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers);
 
     for (size_t i = 0; i < last_tokens.size(); i++) {
         llama_sampler_accept(sampler, last_tokens[i]);
