Skip to content

Commit 5169d1e

Browse files
committed
DRY: Removing sequence breaker functions
1 parent 2ef3cc9 commit 5169d1e

File tree

4 files changed, with 71 additions and 120 deletions.

common/sampling.cpp

Lines changed: 3 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,8 @@
77
#include <cmath>
88
#include <unordered_map>
99

10-
extern void llama_sampler_dry_set_seq_breakers(struct llama_sampler * smpl, const std::vector<std::string>& seq_breakers);
10+
extern struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base,
11+
int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::string>& seq_breakers);
1112

1213
// the ring buffer works similarly to std::deque, but with a fixed capacity
1314
// TODO: deduplicate with llama-impl.h
@@ -186,11 +187,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
186187
for (const auto & cnstr : params.samplers) {
187188
switch (cnstr) {
188189
case COMMON_SAMPLER_TYPE_DRY:
189-
dry_sampler = llama_sampler_init_dry(model, context_size, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n);
190-
if (dry_sampler != nullptr) {
191-
llama_sampler_dry_set_seq_breakers(dry_sampler, params.dry_sequence_breakers);
192-
llama_sampler_chain_add(result->chain, dry_sampler);
193-
}
190+
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, context_size, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, params.dry_sequence_breakers));
194191
break;
195192
case COMMON_SAMPLER_TYPE_TOP_K:
196193
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
@@ -244,11 +241,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
244241
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
245242
}
246243

247-
// // If DRY sampler wasn't added to the chain, free it
248-
// if (dry_sampler) {
249-
// llama_sampler_free(dry_sampler);
250-
// }
251-
252244
return result;
253245
}
254246

include/llama.h

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1137,17 +1137,14 @@ extern "C" {
11371137
bool penalize_nl, // consider newlines as a repeatable token
11381138
bool ignore_eos); // ignore the end-of-sequence token
11391139

1140-
/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1141-
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
1142-
const struct llama_model * model,
1143-
int32_t context_size,
1144-
float dry_multiplier,
1145-
float dry_base,
1146-
int32_t dry_allowed_length,
1147-
int32_t dry_penalty_last_n);
1148-
1149-
LLAMA_API void llama_sampler_dry_set_seq_breakers_c(
1150-
struct llama_sampler * smpl,
1140+
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
1141+
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
1142+
const struct llama_model * model,
1143+
int32_t context_size,
1144+
float dry_multiplier,
1145+
float dry_base,
1146+
int32_t dry_allowed_length,
1147+
int32_t dry_penalty_last_n,
11511148
const char ** seq_breakers,
11521149
int num_breakers);
11531150

src/llama-sampling.cpp

Lines changed: 58 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -1679,7 +1679,7 @@ struct llama_sampler_dry {
16791679
};
16801680

16811681
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1682-
static void GetOverlappingTokenSequences(const struct llama_model * model, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
1682+
static void get_overlapping_token_sequences(const struct llama_model * model, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
16831683
const int n_vocab = llama_n_vocab(model);
16841684
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
16851685
std::string word = llama_detokenize(model, {token_id}, true);
@@ -1721,8 +1721,6 @@ static void GetOverlappingTokenSequences(const struct llama_model * model, const
17211721
}
17221722
}
17231723

1724-
1725-
17261724
static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
17271725
return "dry";
17281726
}
@@ -1952,9 +1950,10 @@ static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
19521950

19531951
static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
19541952
const auto * ctx = (llama_sampler_dry *) smpl->ctx;
1955-
auto * result = llama_sampler_init_dry(ctx->model, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n);
19561953

1957-
// copy the state
1954+
auto * result = llama_sampler_init_dry(ctx->model, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
1955+
1956+
// Copy the state, including the processed breakers
19581957
{
19591958
auto * result_ctx = (llama_sampler_dry *) result->ctx;
19601959
result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
@@ -1979,44 +1978,16 @@ static struct llama_sampler_i llama_sampler_dry_i = {
19791978
/* .free = */ llama_sampler_dry_free,
19801979
};
19811980

1982-
struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n) {
1981+
struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::string>& seq_breakers) {
19831982
if (dry_multiplier < 0 || dry_base <= 0 || dry_allowed_length < 0) {
19841983
return nullptr;
19851984
}
19861985

19871986
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
19881987

1989-
return new llama_sampler {
1990-
/* .iface = */ &llama_sampler_dry_i,
1991-
/* .ctx = */ new llama_sampler_dry {
1992-
/* .model = */ model,
1993-
/* .total_context_size = */ context_size,
1994-
/* .dry_multiplier = */ dry_multiplier,
1995-
/* .dry_base = */ dry_base,
1996-
/* .dry_allowed_length = */ dry_allowed_length,
1997-
/* .dry_penalty_last_n = */ dry_penalty_last_n,
1998-
/* .dry_processed_breakers = */ {},
1999-
/* .dry_repeat_count = */ std::vector<int>(effective_dry_penalty_last_n, 0),
2000-
/* .dry_max_token_repeat = */ {},
2001-
/* .last_tokens = */ ring_buffer<llama_token>(effective_dry_penalty_last_n),
2002-
},
2003-
};
2004-
}
2005-
2006-
void llama_sampler_dry_set_seq_breakers(struct llama_sampler * smpl, const std::vector<std::string>& seq_breakers) {
2007-
if (smpl == nullptr || smpl->ctx == nullptr) {
2008-
LLAMA_LOG_ERROR("invalid sampler or context in llama_sampler_dry_set_seq_breakers");
2009-
return;
2010-
}
2011-
2012-
auto * ctx = (llama_sampler_dry *) smpl->ctx;
2013-
ctx->dry_processed_breakers.clear();
2014-
2015-
if (seq_breakers.empty()) {
2016-
LLAMA_LOG_WARN("empty sequence breakers list in llama_sampler_dry_set_seq_breakers");
2017-
return;
2018-
}
1988+
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
20191989

1990+
// Process sequence breakers
20201991
const int MAX_CHAR_LEN = 40;
20211992
const int MAX_SEQ_LEN = 20;
20221993

@@ -2032,72 +2003,74 @@ void llama_sampler_dry_set_seq_breakers(struct llama_sampler * smpl, const std::
20322003
trimmed_break.resize(MAX_CHAR_LEN);
20332004
}
20342005

2035-
GetOverlappingTokenSequences(ctx->model, trimmed_break, ctx->dry_processed_breakers, MAX_SEQ_LEN);
2006+
get_overlapping_token_sequences(model, trimmed_break, processed_breakers, MAX_SEQ_LEN);
20362007
}
2037-
}
20382008

2039-
// For C-interface
2040-
void llama_sampler_dry_set_seq_breakers_c(struct llama_sampler * smpl, const char **seq_breakers, int num_breakers) {
2041-
if (smpl == nullptr || smpl->ctx == nullptr) {
2042-
LLAMA_LOG_ERROR("invalid sampler or context in llama_sampler_dry_set_seq_breakers_c");
2043-
return;
2044-
}
2045-
2046-
if (seq_breakers == nullptr || num_breakers <= 0) {
2047-
LLAMA_LOG_ERROR("invalid sequence breakers array or count in llama_sampler_dry_set_seq_breakers_c");
2048-
return;
2049-
}
2009+
return new llama_sampler {
2010+
/* .iface = */ &llama_sampler_dry_i,
2011+
/* .ctx = */ new llama_sampler_dry {
2012+
/* .model = */ model,
2013+
/* .total_context_size = */ context_size,
2014+
/* .dry_multiplier = */ dry_multiplier,
2015+
/* .dry_base = */ dry_base,
2016+
/* .dry_allowed_length = */ dry_allowed_length,
2017+
/* .dry_penalty_last_n = */ dry_penalty_last_n,
2018+
/* .dry_processed_breakers = */ std::move(processed_breakers),
2019+
/* .dry_repeat_count = */ std::vector<int>(effective_dry_penalty_last_n, 0),
2020+
/* .dry_max_token_repeat = */ {},
2021+
/* .last_tokens = */ ring_buffer<llama_token>(effective_dry_penalty_last_n),
2022+
},
2023+
};
2024+
}
20502025

2026+
// overloaded wrapper meant for C-interface
2027+
struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, int num_breakers) {
20512028
std::vector<std::string> cpp_seq_breakers;
2052-
for (int i = 0; i < num_breakers; ++i) {
2053-
if (seq_breakers[i] == nullptr) {
2054-
LLAMA_LOG_WARN("skipping null sequence breaker at index %d", i);
2055-
continue;
2056-
}
2057-
if (std::strlen(seq_breakers[i]) == 0) {
2058-
LLAMA_LOG_WARN("skipping empty sequence breaker at index %d", i);
2059-
continue;
2060-
}
2061-
cpp_seq_breakers.push_back(std::string(seq_breakers[i]));
2062-
}
20632029

2064-
if (cpp_seq_breakers.empty()) {
2065-
LLAMA_LOG_WARN("no valid sequence breakers found in llama_sampler_dry_set_seq_breakers_c");
2066-
return;
2030+
if (seq_breakers != nullptr && num_breakers > 0) {
2031+
for (int i = 0; i < num_breakers; ++i) {
2032+
if (seq_breakers[i] != nullptr && std::strlen(seq_breakers[i]) > 0) {
2033+
cpp_seq_breakers.push_back(std::string(seq_breakers[i]));
2034+
} else {
2035+
LLAMA_LOG_WARN("skipping null or empty sequence breaker at index %d", i);
2036+
}
2037+
}
20672038
}
20682039

2069-
llama_sampler_dry_set_seq_breakers(smpl, cpp_seq_breakers);
2040+
return llama_sampler_init_dry(model, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, cpp_seq_breakers);
20702041
}
20712042

2072-
// For use in test-sampling.cpp
2073-
void llama_sampler_dry_set_seq_breakers_as_tokens(struct llama_sampler * smpl, const std::vector<std::vector<llama_token>>& seq_breakers) {
2074-
if (smpl == nullptr || smpl->ctx == nullptr) {
2075-
LLAMA_LOG_ERROR("invalid sampler or context in llama_sampler_dry_set_seq_breakers_as_tokens");
2076-
return;
2043+
// overloaded wrapper meant for test-sampling.cpp
2044+
struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
2045+
auto * result = llama_sampler_init_dry(model, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
2046+
2047+
if (result == nullptr) {
2048+
return nullptr;
20772049
}
20782050

2079-
auto * ctx = (llama_sampler_dry *) smpl->ctx;
2080-
ctx->dry_processed_breakers.clear();
2051+
auto * ctx = (llama_sampler_dry *) result->ctx;
20812052

2053+
// Process the token-based sequence breakers
2054+
ctx->dry_processed_breakers.clear();
20822055
if (seq_breakers.empty()) {
2083-
LLAMA_LOG_WARN("empty sequence breakers list in llama_sampler_dry_set_seq_breakers_as_tokens");
2084-
return;
2085-
}
2086-
2087-
for (const auto& breaker : seq_breakers) {
2088-
if (breaker.empty()) {
2089-
LLAMA_LOG_WARN("skipping empty token sequence");
2090-
continue;
2056+
LLAMA_LOG_WARN("empty sequence breakers list in llama_sampler_init_dry");
2057+
} else {
2058+
for (const auto& breaker : seq_breakers) {
2059+
if (breaker.empty()) {
2060+
LLAMA_LOG_WARN("skipping empty token sequence");
2061+
continue;
2062+
}
2063+
llama_token head_token = breaker[0];
2064+
std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
2065+
ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
20912066
}
20922067

2093-
llama_token head_token = breaker[0];
2094-
std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
2095-
ctx->dry_processed_breakers.emplace(head_token, tail_tokens);
2068+
if (ctx->dry_processed_breakers.empty()) {
2069+
LLAMA_LOG_WARN("no valid sequence breakers processed in llama_sampler_init_dry");
2070+
}
20962071
}
20972072

2098-
if (ctx->dry_processed_breakers.empty()) {
2099-
LLAMA_LOG_WARN("no valid sequence breakers processed in llama_sampler_dry_set_seq_breakers_as_tokens");
2100-
}
2073+
return result;
21012074
}
21022075

21032076
// logit-bias

tests/test-sampling.cpp

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
#include <string>
1111
#include <vector>
1212

13-
extern void llama_sampler_dry_set_seq_breakers_as_tokens(struct llama_sampler * smpl, const std::vector<std::vector<llama_token>>& seq_breakers);
13+
extern struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers);
1414

1515
static void dump(const llama_token_data_array * cur_p) {
1616
for (size_t i = 0; i < cur_p->size; i++) {
@@ -211,18 +211,7 @@ static void test_dry(
211211
const int32_t context_size = 1024;
212212
struct llama_model * model = nullptr;
213213

214-
struct llama_sampler * sampler = llama_sampler_init_dry(
215-
model,
216-
context_size,
217-
dry_multiplier,
218-
dry_base,
219-
dry_allowed_length,
220-
dry_penalty_last_n
221-
);
222-
223-
if (!seq_breakers.empty()) {
224-
llama_sampler_dry_set_seq_breakers_as_tokens(sampler, seq_breakers);
225-
}
214+
struct llama_sampler * sampler = llama_sampler_init_dry(model, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers);
226215

227216
for (size_t i = 0; i < last_tokens.size(); i++) {
228217
llama_sampler_accept(sampler, last_tokens[i]);

0 commit comments

Comments
 (0)