Skip to content

Commit dc408bb

Browse files
committed
DRY: Fixed crash issue due to DRY being in chain but uninitialized
1 parent 1303893 commit dc408bb

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

src/llama-sampling.cpp

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1747,7 +1747,7 @@ static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/
17471747

17481748
static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
17491749
auto * ctx = (llama_sampler_dry *) smpl->ctx;
1750-
if (ctx->dry_penalty_last_n == 0) {
1750+
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
17511751
return;
17521752
}
17531753

@@ -1758,7 +1758,7 @@ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token to
17581758
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
17591759
auto * ctx = (llama_sampler_dry *) smpl->ctx;
17601760

1761-
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f) {
1761+
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
17621762
return;
17631763
}
17641764

@@ -1999,18 +1999,15 @@ static struct llama_sampler_i llama_sampler_dry_i = {
19991999
};
20002000

20012001
struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
2002-
if (dry_multiplier == 0.0f || dry_base < 1.0f) {
2003-
return nullptr;
2004-
}
2005-
20062002
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
2007-
2008-
// Process sequence breakers
20092003
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
20102004
const int MAX_CHAR_LEN = 40;
20112005
const int MAX_SEQ_LEN = 20;
20122006

2013-
if (seq_breakers != nullptr && num_breakers > 0) {
2007+
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
2008+
2009+
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
2010+
// Process sequence breakers
20142011
for (size_t i = 0; i < num_breakers; ++i) {
20152012
if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
20162013
LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
@@ -2041,9 +2038,9 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
20412038
/* .dry_allowed_length = */ dry_allowed_length,
20422039
/* .dry_penalty_last_n = */ dry_penalty_last_n,
20432040
/* .dry_processed_breakers = */ std::move(processed_breakers),
2044-
/* .dry_repeat_count = */ std::vector<int>(effective_dry_penalty_last_n, 0),
2041+
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
20452042
/* .dry_max_token_repeat = */ {},
2046-
/* .last_tokens = */ ring_buffer<llama_token>(effective_dry_penalty_last_n),
2043+
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
20472044
},
20482045
};
20492046
}
@@ -2052,11 +2049,6 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
20522049
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
20532050
llama_vocab dummy_vocab;
20542051
auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
2055-
2056-
if (result == nullptr) {
2057-
return nullptr;
2058-
}
2059-
20602052
auto * ctx = (llama_sampler_dry *) result->ctx;
20612053

20622054
// Process the token-based sequence breakers

0 commit comments

Comments
 (0)