@@ -1679,7 +1679,7 @@ struct llama_sampler_dry {
16791679};
16801680
16811681// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1682- static void GetOverlappingTokenSequences (const struct llama_model * model, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1 ) {
1682+ static void get_overlapping_token_sequences (const struct llama_model * model, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1 ) {
16831683 const int n_vocab = llama_n_vocab (model);
16841684 for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
16851685 std::string word = llama_detokenize (model, {token_id}, true );
@@ -1721,8 +1721,6 @@ static void GetOverlappingTokenSequences(const struct llama_model * model, const
17211721 }
17221722}
17231723
1724-
1725-
17261724static const char * llama_sampler_dry_name (const struct llama_sampler * /* smpl*/ ) {
17271725 return " dry" ;
17281726}
@@ -1952,9 +1950,10 @@ static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
19521950
19531951static struct llama_sampler * llama_sampler_dry_clone (const struct llama_sampler * smpl) {
19541952 const auto * ctx = (llama_sampler_dry *) smpl->ctx ;
1955- auto * result = llama_sampler_init_dry (ctx->model , ctx->total_context_size , ctx->dry_multiplier , ctx->dry_base , ctx->dry_allowed_length , ctx->dry_penalty_last_n );
19561953
1957- // copy the state
1954+ auto * result = llama_sampler_init_dry (ctx->model , ctx->total_context_size , ctx->dry_multiplier , ctx->dry_base , ctx->dry_allowed_length , ctx->dry_penalty_last_n , NULL , 0 );
1955+
1956+ // Copy the state, including the processed breakers
19581957 {
19591958 auto * result_ctx = (llama_sampler_dry *) result->ctx ;
19601959 result_ctx->dry_processed_breakers = ctx->dry_processed_breakers ;
@@ -1979,44 +1978,16 @@ static struct llama_sampler_i llama_sampler_dry_i = {
19791978 /* .free = */ llama_sampler_dry_free,
19801979};
19811980
1982- struct llama_sampler * llama_sampler_init_dry (const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n) {
1981+ struct llama_sampler * llama_sampler_init_dry (const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::string>& seq_breakers ) {
19831982 if (dry_multiplier < 0 || dry_base <= 0 || dry_allowed_length < 0 ) {
19841983 return nullptr ;
19851984 }
19861985
19871986 int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1 ) ? context_size : std::max (dry_penalty_last_n, 0 );
19881987
1989- return new llama_sampler {
1990- /* .iface = */ &llama_sampler_dry_i,
1991- /* .ctx = */ new llama_sampler_dry {
1992- /* .model = */ model,
1993- /* .total_context_size = */ context_size,
1994- /* .dry_multiplier = */ dry_multiplier,
1995- /* .dry_base = */ dry_base,
1996- /* .dry_allowed_length = */ dry_allowed_length,
1997- /* .dry_penalty_last_n = */ dry_penalty_last_n,
1998- /* .dry_processed_breakers = */ {},
1999- /* .dry_repeat_count = */ std::vector<int >(effective_dry_penalty_last_n, 0 ),
2000- /* .dry_max_token_repeat = */ {},
2001- /* .last_tokens = */ ring_buffer<llama_token>(effective_dry_penalty_last_n),
2002- },
2003- };
2004- }
2005-
2006- void llama_sampler_dry_set_seq_breakers (struct llama_sampler * smpl, const std::vector<std::string>& seq_breakers) {
2007- if (smpl == nullptr || smpl->ctx == nullptr ) {
2008- LLAMA_LOG_ERROR (" invalid sampler or context in llama_sampler_dry_set_seq_breakers" );
2009- return ;
2010- }
2011-
2012- auto * ctx = (llama_sampler_dry *) smpl->ctx ;
2013- ctx->dry_processed_breakers .clear ();
2014-
2015- if (seq_breakers.empty ()) {
2016- LLAMA_LOG_WARN (" empty sequence breakers list in llama_sampler_dry_set_seq_breakers" );
2017- return ;
2018- }
1988+ std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
20191989
1990+ // Process sequence breakers
20201991 const int MAX_CHAR_LEN = 40 ;
20211992 const int MAX_SEQ_LEN = 20 ;
20221993
@@ -2032,72 +2003,74 @@ void llama_sampler_dry_set_seq_breakers(struct llama_sampler * smpl, const std::
20322003 trimmed_break.resize (MAX_CHAR_LEN);
20332004 }
20342005
2035- GetOverlappingTokenSequences (ctx-> model , trimmed_break, ctx-> dry_processed_breakers , MAX_SEQ_LEN);
2006+ get_overlapping_token_sequences ( model, trimmed_break, processed_breakers , MAX_SEQ_LEN);
20362007 }
2037- }
20382008
2039- // For C-interface
2040- void llama_sampler_dry_set_seq_breakers_c (struct llama_sampler * smpl, const char **seq_breakers, int num_breakers) {
2041- if (smpl == nullptr || smpl->ctx == nullptr ) {
2042- LLAMA_LOG_ERROR (" invalid sampler or context in llama_sampler_dry_set_seq_breakers_c" );
2043- return ;
2044- }
2045-
2046- if (seq_breakers == nullptr || num_breakers <= 0 ) {
2047- LLAMA_LOG_ERROR (" invalid sequence breakers array or count in llama_sampler_dry_set_seq_breakers_c" );
2048- return ;
2049- }
2009+ return new llama_sampler {
2010+ /* .iface = */ &llama_sampler_dry_i,
2011+ /* .ctx = */ new llama_sampler_dry {
2012+ /* .model = */ model,
2013+ /* .total_context_size = */ context_size,
2014+ /* .dry_multiplier = */ dry_multiplier,
2015+ /* .dry_base = */ dry_base,
2016+ /* .dry_allowed_length = */ dry_allowed_length,
2017+ /* .dry_penalty_last_n = */ dry_penalty_last_n,
2018+ /* .dry_processed_breakers = */ std::move (processed_breakers),
2019+ /* .dry_repeat_count = */ std::vector<int >(effective_dry_penalty_last_n, 0 ),
2020+ /* .dry_max_token_repeat = */ {},
2021+ /* .last_tokens = */ ring_buffer<llama_token>(effective_dry_penalty_last_n),
2022+ },
2023+ };
2024+ }
20502025
2026+ // overloaded wrapper meant for C-interface
2027+ struct llama_sampler * llama_sampler_init_dry (const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char ** seq_breakers, int num_breakers) {
20512028 std::vector<std::string> cpp_seq_breakers;
2052- for (int i = 0 ; i < num_breakers; ++i) {
2053- if (seq_breakers[i] == nullptr ) {
2054- LLAMA_LOG_WARN (" skipping null sequence breaker at index %d" , i);
2055- continue ;
2056- }
2057- if (std::strlen (seq_breakers[i]) == 0 ) {
2058- LLAMA_LOG_WARN (" skipping empty sequence breaker at index %d" , i);
2059- continue ;
2060- }
2061- cpp_seq_breakers.push_back (std::string (seq_breakers[i]));
2062- }
20632029
2064- if (cpp_seq_breakers.empty ()) {
2065- LLAMA_LOG_WARN (" no valid sequence breakers found in llama_sampler_dry_set_seq_breakers_c" );
2066- return ;
2030+ if (seq_breakers != nullptr && num_breakers > 0 ) {
2031+ for (int i = 0 ; i < num_breakers; ++i) {
2032+ if (seq_breakers[i] != nullptr && std::strlen (seq_breakers[i]) > 0 ) {
2033+ cpp_seq_breakers.push_back (std::string (seq_breakers[i]));
2034+ } else {
2035+ LLAMA_LOG_WARN (" skipping null or empty sequence breaker at index %d" , i);
2036+ }
2037+ }
20672038 }
20682039
2069- llama_sampler_dry_set_seq_breakers (smpl , cpp_seq_breakers);
2040+ return llama_sampler_init_dry (model, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n , cpp_seq_breakers);
20702041}
20712042
2072- // For use in test-sampling.cpp
2073- void llama_sampler_dry_set_seq_breakers_as_tokens (struct llama_sampler * smpl, const std::vector<std::vector<llama_token>>& seq_breakers) {
2074- if (smpl == nullptr || smpl->ctx == nullptr ) {
2075- LLAMA_LOG_ERROR (" invalid sampler or context in llama_sampler_dry_set_seq_breakers_as_tokens" );
2076- return ;
2043+ // overloaded wrapper meant for test-sampling.cpp
2044+ struct llama_sampler * llama_sampler_init_dry (const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
2045+ auto * result = llama_sampler_init_dry (model, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL , 0 );
2046+
2047+ if (result == nullptr ) {
2048+ return nullptr ;
20772049 }
20782050
2079- auto * ctx = (llama_sampler_dry *) smpl->ctx ;
2080- ctx->dry_processed_breakers .clear ();
2051+ auto * ctx = (llama_sampler_dry *) result->ctx ;
20812052
2053+ // Process the token-based sequence breakers
2054+ ctx->dry_processed_breakers .clear ();
20822055 if (seq_breakers.empty ()) {
2083- LLAMA_LOG_WARN (" empty sequence breakers list in llama_sampler_dry_set_seq_breakers_as_tokens" );
2084- return ;
2085- }
2086-
2087- for (const auto & breaker : seq_breakers) {
2088- if (breaker.empty ()) {
2089- LLAMA_LOG_WARN (" skipping empty token sequence" );
2090- continue ;
2056+ LLAMA_LOG_WARN (" empty sequence breakers list in llama_sampler_init_dry" );
2057+ } else {
2058+ for (const auto & breaker : seq_breakers) {
2059+ if (breaker.empty ()) {
2060+ LLAMA_LOG_WARN (" skipping empty token sequence" );
2061+ continue ;
2062+ }
2063+ llama_token head_token = breaker[0 ];
2064+ std::vector<llama_token> tail_tokens (breaker.begin () + 1 , breaker.end ());
2065+ ctx->dry_processed_breakers .emplace (head_token, std::move (tail_tokens));
20912066 }
20922067
2093- llama_token head_token = breaker[ 0 ];
2094- std::vector<llama_token> tail_tokens (breaker. begin () + 1 , breaker. end () );
2095- ctx-> dry_processed_breakers . emplace (head_token, tail_tokens);
2068+ if (ctx-> dry_processed_breakers . empty ()) {
2069+ LLAMA_LOG_WARN ( " no valid sequence breakers processed in llama_sampler_init_dry " );
2070+ }
20962071 }
20972072
2098- if (ctx->dry_processed_breakers .empty ()) {
2099- LLAMA_LOG_WARN (" no valid sequence breakers processed in llama_sampler_dry_set_seq_breakers_as_tokens" );
2100- }
2073+ return result;
21012074}
21022075
21032076// logit-bias
0 commit comments