@@ -1747,7 +1747,7 @@ static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/
17471747
17481748static void llama_sampler_dry_accept (struct llama_sampler * smpl, llama_token token) {
17491749 auto * ctx = (llama_sampler_dry *) smpl->ctx ;
1750- if (ctx->dry_penalty_last_n == 0 ) {
1750+ if (ctx->dry_multiplier == 0 . 0f || ctx-> dry_base < 1 . 0f || ctx-> dry_penalty_last_n == 0 ) {
17511751 return ;
17521752 }
17531753
@@ -1758,7 +1758,7 @@ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token to
17581758static void llama_sampler_dry_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p) {
17591759 auto * ctx = (llama_sampler_dry *) smpl->ctx ;
17601760
1761- if (ctx->dry_multiplier == 0 .0f || ctx->dry_base < 1 .0f ) {
1761+ if (ctx->dry_multiplier == 0 .0f || ctx->dry_base < 1 .0f || ctx-> dry_penalty_last_n == 0 ) {
17621762 return ;
17631763 }
17641764
@@ -1999,18 +1999,15 @@ static struct llama_sampler_i llama_sampler_dry_i = {
19991999};
20002000
20012001struct llama_sampler * llama_sampler_init_dry_impl (const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char ** seq_breakers, size_t num_breakers) {
2002- if (dry_multiplier == 0 .0f || dry_base < 1 .0f ) {
2003- return nullptr ;
2004- }
2005-
20062002 int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1 ) ? context_size : std::max (dry_penalty_last_n, 0 );
2007-
2008- // Process sequence breakers
20092003 std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
20102004 const int MAX_CHAR_LEN = 40 ;
20112005 const int MAX_SEQ_LEN = 20 ;
20122006
2013- if (seq_breakers != nullptr && num_breakers > 0 ) {
2007+ const bool dry_enabled = (dry_multiplier != 0 .0f && dry_base >= 1 .0f && dry_penalty_last_n != 0 );
2008+
2009+ if (dry_enabled && seq_breakers != nullptr && num_breakers > 0 ) {
2010+ // Process sequence breakers
20142011 for (size_t i = 0 ; i < num_breakers; ++i) {
20152012 if (seq_breakers[i] == nullptr || std::strlen (seq_breakers[i]) == 0 ) {
20162013 LLAMA_LOG_WARN (" skipping null or empty DRY sequence breaker at index %zu\n " , i);
@@ -2041,9 +2038,9 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
20412038 /* .dry_allowed_length = */ dry_allowed_length,
20422039 /* .dry_penalty_last_n = */ dry_penalty_last_n,
20432040 /* .dry_processed_breakers = */ std::move (processed_breakers),
2044- /* .dry_repeat_count = */ std::vector<int >(effective_dry_penalty_last_n, 0 ),
2041+ /* .dry_repeat_count = */ dry_enabled ? std::vector<int >(effective_dry_penalty_last_n, 0 ) : std::vector< int >{} ,
20452042 /* .dry_max_token_repeat = */ {},
2046- /* .last_tokens = */ ring_buffer<llama_token>(effective_dry_penalty_last_n),
2043+ /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>( 0 ),
20472044 },
20482045 };
20492046}
@@ -2052,11 +2049,6 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
20522049struct llama_sampler * llama_sampler_init_dry_testing (int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
20532050 llama_vocab dummy_vocab;
20542051 auto * result = llama_sampler_init_dry_impl (dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL , 0 );
2055-
2056- if (result == nullptr ) {
2057- return nullptr ;
2058- }
2059-
20602052 auto * ctx = (llama_sampler_dry *) result->ctx ;
20612053
20622054 // Process the token-based sequence breakers
0 commit comments