@@ -1663,8 +1663,6 @@ struct llama_sampler * llama_sampler_init_penalties(
 // DRY
 
 struct llama_sampler_dry {
-    const llama_model * model;
-
     int32_t total_context_size;
 
     const float dry_multiplier;
@@ -1679,10 +1677,9 @@ struct llama_sampler_dry {
 };
 
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
-static void get_overlapping_token_sequences(const struct llama_model * model, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
-    const int n_vocab = llama_n_vocab(model);
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        std::string word = llama_detokenize(model, {token_id}, true);
+static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
+    for (llama_token token_id = 0; token_id < (int) vocab.n_vocab; token_id++) {
+        std::string word = llama_detokenize(vocab, {token_id}, true);
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
         } else {
@@ -1698,7 +1695,7 @@ static void get_overlapping_token_sequences(const struct llama_model * model, co
                     }
                 }
                 if (match) {
-                    std::vector<llama_token> tokenization = llama_tokenize_internal(model, str.substr(i), false, false);
+                    std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
                     if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                         tokenization.resize(max_tail_len);
                     }
@@ -1951,8 +1948,7 @@ static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (llama_sampler_dry *) smpl->ctx;
 
-    auto * result = llama_sampler_init_dry(ctx->model, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
-
+    auto * result = llama_sampler_init_dry(nullptr, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
     // Copy the state, including the processed breakers
     {
         auto * result_ctx = (llama_sampler_dry *) result->ctx;
@@ -1978,7 +1974,7 @@ static struct llama_sampler_i llama_sampler_dry_i = {
     /* .free   = */ llama_sampler_dry_free,
 };
 
-struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char ** seq_breakers, size_t num_breakers) {
+struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char ** seq_breakers, size_t num_breakers) {
     if (dry_multiplier < 0 || dry_base <= 0 || dry_allowed_length < 0) {
         return nullptr;
     }
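The validation above mirrors how these parameters feed the DRY penalty described in the Koboldcpp PR linked earlier: once a repeated suffix of length `n >= dry_allowed_length` is found, the candidate's logit is reduced by `dry_multiplier * dry_base^(n - dry_allowed_length)`, which is why a negative multiplier, a non-positive base or a negative allowed length is rejected. A standalone sketch of that curve (example values only, not code from this diff):

```cpp
#include <cmath>
#include <cstdio>

// Standalone sketch of the DRY penalty curve; parameter names mirror the
// sampler fields, the numeric values below are illustrative only.
static float dry_penalty(float dry_multiplier, float dry_base, int dry_allowed_length, int repeat_len) {
    if (repeat_len < dry_allowed_length) {
        return 0.0f;  // repeats shorter than the allowed length are not penalized
    }
    return dry_multiplier * std::pow(dry_base, (float) (repeat_len - dry_allowed_length));
}

int main() {
    for (int n = 1; n <= 8; ++n) {
        printf("repeat_len=%d -> logit penalty=%.3f\n", n, dry_penalty(0.8f, 1.75f, 2, n));
    }
    return 0;
}
```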
@@ -2008,14 +2004,13 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model,
                 sequence_break.resize(MAX_CHAR_LEN);
             }
 
-            get_overlapping_token_sequences(model, sequence_break, processed_breakers, MAX_SEQ_LEN);
+            get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
         }
     }
 
     return new llama_sampler {
         /* .iface = */ &llama_sampler_dry_i,
         /* .ctx   = */ new llama_sampler_dry {
-            /* .model              = */ model,
             /* .total_context_size = */ context_size,
             /* .dry_multiplier     = */ dry_multiplier,
             /* .dry_base           = */ dry_base,