Commit c210cba

wwoodsTM and pi6am committed
DRY: WIP, restoring Koboldcpp comments, adding attribution in llama.h and llama-sampling.cpp
Co-authored-by: pi6am <[email protected]>
1 parent a50603d commit c210cba

3 files changed, +85 -7 lines changed

examples/server/server.cpp

Lines changed: 7 additions & 1 deletion
@@ -898,7 +898,13 @@ struct server_context {
             if (dry_sequence_breakers->is_array()) {
                 slot.sparams.dry_sequence_breakers = dry_sequence_breakers->get<std::vector<std::string>>();
             } else if (dry_sequence_breakers->is_string()) {
-                slot.sparams.dry_sequence_breakers = json::parse(dry_sequence_breakers->get<std::string>()).get<std::vector<std::string>>();
+                std::string dry_sequence_breakers_str = dry_sequence_breakers->get<std::string>();
+
+                if (dry_sequence_breakers_str.empty() || dry_sequence_breakers_str[0] != '[') {
+                    dry_sequence_breakers_str = "[" + dry_sequence_breakers_str + "]";
+                }
+
+                slot.sparams.dry_sequence_breakers = json::parse(dry_sequence_breakers_str).get<std::vector<std::string>>();
             } else {
                 send_error(task, "\"dry_sequence_breakers\": Expected an array of strings or a JSON-encoded array of strings.", ERROR_TYPE_INVALID_REQUEST);
                 return false;
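
For illustration, a minimal self-contained sketch of the inputs accepted after this change, assuming nlohmann::json (the library behind server.cpp's `json` alias): a string value that does not already start with '[' is wrapped in brackets before parsing, so clients may send either a JSON-encoded array or a bare comma-separated list of JSON strings.

#include <iostream>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    for (std::string s : {std::string("[\"\\n\", \":\"]"),   // already a JSON array
                          std::string("\"\\n\", \":\"")}) {  // bare list of strings
        // Same wrapping rule as in the diff above.
        if (s.empty() || s[0] != '[') {
            s = "[" + s + "]";
        }
        std::vector<std::string> breakers = json::parse(s).get<std::vector<std::string>>();
        std::cout << breakers.size() << " sequence breakers parsed\n"; // prints 2 both times
    }
}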

include/llama.h

Lines changed: 1 addition & 1 deletion
@@ -1149,7 +1149,7 @@ extern "C" {
             bool   penalize_nl,    // consider newlines as a repeatable token
             bool   ignore_eos);    // ignore the end-of-sequence token

-    /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+    /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, ported from the Koboldcpp implementation: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
     LLAMA_API struct llama_sampler * llama_sampler_init_dry(
             const struct llama_model * model,
                              int32_t   context_size,

src/llama-sampling.cpp

Lines changed: 77 additions & 5 deletions
@@ -1678,6 +1678,7 @@ struct llama_sampler_dry {
     ring_buffer<llama_token> last_tokens;
 };

+// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void GetOverlappingTokenSequences(const struct llama_model * model, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
     const int n_vocab = llama_n_vocab(model);
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -1735,6 +1736,7 @@ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
     ctx->last_tokens.push_back(token);
 }

+// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dry *) smpl->ctx;

@@ -1752,7 +1754,28 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     ctx->dry_repeat_count.assign(last_n_repeat, 0);
     ctx->dry_max_token_repeat.clear();

-    // Step 1: Look for restart sequences
+    // Step 1: Look for restart sequences to limit the maximum repetition length.
+    // Work backwards through the context looking for any token that begins a restart sequence.
+    //
+    // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
+    // sequences that together comprise a restart sequence. This allows us to quickly check
+    // whether each token is the head of a complete sequence. Most restart sequences are actually
+    // a single token, and for these the "tail" is an empty vector.
+    //
+    // If the token is a "head", test all restart sequences that begin with this token
+    // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
+    // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
+    // longest matching sequence (if any) is used to limit the maximum repetition length.
+    //
+    // Note that in the case of a short sequence contained in a longer one, this might fail to
+    // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
+    // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
+    // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
+    //
+    // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
+    // have already clamped the maximum tail sequence length when generating `restart_sequences`.
+    // With clamping, this scan is O(N) in the context length.
+
     int rep_limit = last_n_repeat;
     for (int i = 0; i < last_n_repeat; ++i) {
         llama_token token = ctx->last_tokens.rat(i);
@@ -1762,10 +1785,15 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         }
         int longest_match = -1;
         for (auto it = its.first; it != its.second; ++it) {
+            // Note that (*it) does not contain the head character, so seq_len will be
+            // the restart sequence length minus 1.
+            // In the common case of a single-token restart sequence, (*it) will be empty
+            // and we will trivially match.
             int seq_len = (int)it->second.size();
             if (seq_len > longest_match && seq_len <= (int)i) {
                 bool match = true;
                 for (int offset = 0; offset < seq_len; ++offset) {
+                    // The -1 when indexing `last_tokens` is because we already matched the head.
                     if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
                         match = false;
                         break;
@@ -1777,6 +1805,8 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
                 }
             }
             if (longest_match >= 0) {
+                // We found a restart sequence starting `i` tokens from the end and continuing for
+                // `longest_match` tokens.
                 rep_limit = i - longest_match;
                 break;
             }
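
To make the Step 1 data structure concrete, here is a minimal self-contained sketch of the "head -> tails" mapping described in the comments above. `fake_tokenize` (one token per character) is a hypothetical stand-in for real tokenization; the actual code derives these sequences from the model vocabulary in GetOverlappingTokenSequences.

#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

using llama_token = int;

// Hypothetical stand-in for tokenization: one token per character.
static std::vector<llama_token> fake_tokenize(const std::string & s) {
    return std::vector<llama_token>(s.begin(), s.end());
}

int main() {
    const std::vector<std::string> breakers = {"\n", "aaaq1", "aaa1"};
    std::unordered_multimap<llama_token, std::vector<llama_token>> restart_sequences;

    for (const std::string & b : breakers) {
        const std::vector<llama_token> toks = fake_tokenize(b);
        if (toks.empty()) {
            continue;
        }
        // Key on the head token; store only the tail, which is empty for
        // single-token breakers (those then match trivially).
        restart_sequences.emplace(toks[0], std::vector<llama_token>(toks.begin() + 1, toks.end()));
    }

    // 'aaaq1' and 'aaa1' share the head 'a', so one head can map to several
    // candidate tails; the sampler tests each and keeps the longest match.
    const auto its = restart_sequences.equal_range((llama_token) 'a');
    for (auto it = its.first; it != its.second; ++it) {
        printf("head 'a' -> tail of length %zu\n", it->second.size());
    }
}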
@@ -1785,13 +1815,35 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         return;
     }

-    // Step 2: Z-algorithm implementation
+    // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
+    // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
+    // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
+    //
+    // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
+    // https://ivanyu.me/blog/2014/10/15/z-algorithm/
+    //
+    // The code below is adapted from the public domain implementation by the same author here:
+    // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
+    //
+    // Example:
+    //   Last N tokens: a b c c b c y a b c
+    //   Repeat counts: 0 0 3 1 0 2 0 0 0 0
+    //                      ^
+    //   This `3` means that the last three tokens of the context (a b c) also appear here.
+    //
+    // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
+    // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
+    // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
+    // ensure that the inner while loops only examine each token in the context once as the outer
+    // for loop iterates over the context.
+
     {
         const int last = last_n_repeat - 1;
         int rt = 0, lt = 0;

         for (int k = 1; k < last_n_repeat; ++k) {
             if (k > rt) {
+                // If k is outside the current Z-box, do naive computation.
                 int n = 0;
                 while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
                     ++n;
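
To pin down what these repeat counts mean, here is a naive O(N^2) reference (not the linear Z-algorithm the diff implements) that reproduces the example from the comment above:

#include <cstdio>
#include <string>
#include <vector>

// Naive reference: repeat[i] = length of the longest suffix of the whole
// window that also ends at position i. The final position stays 0 by
// convention (the whole window trivially matches itself).
int main() {
    const std::string toks = "abccbcyabc"; // "Last N tokens" from the example above
    const int n = (int) toks.size();
    std::vector<int> repeat(n, 0);

    for (int i = 0; i + 1 < n; ++i) {
        int len = 0;
        while (len <= i && toks[i - len] == toks[n - 1 - len]) {
            ++len;
        }
        repeat[i] = len;
    }
    for (int r : repeat) {
        printf("%d ", r); // prints: 0 0 3 1 0 2 0 0 0 0
    }
    printf("\n");
}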
@@ -1802,7 +1854,9 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
                     rt = k+n-1;
                 }
             } else {
-                int p = k - lt;
+                // If k is inside the current Z-box, consider two cases.
+
+                int p = k - lt; // Pair index.
                 int right_part_len = rt - k + 1;

                 if (ctx->dry_repeat_count[last - p] < right_part_len) {
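
For comparison, an independent linear-time sketch: running a standard Z-function, with `lt`/`rt` Z-box bounds analogous to the ones in the diff, over the reversed window turns prefix matches into exactly the suffix repeat counts of Step 2. This is an illustration, not the diff's exact code:

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

// Standard Z-function: z[k] = length of the longest common prefix of s and
// s.substr(k). The Z-box bounds keep the total work linear.
static std::vector<int> z_function(const std::string & s) {
    const int n = (int) s.size();
    std::vector<int> z(n, 0);
    int lt = 0, rt = 0; // Z-box [lt, rt), rt exclusive here
    for (int k = 1; k < n; ++k) {
        if (k < rt) {
            // Inside the current Z-box: reuse the value of the paired index,
            // clamped to the box boundary.
            z[k] = std::min(rt - k, z[k - lt]);
        }
        // Extend naively past the box (or from scratch if k >= rt).
        while (k + z[k] < n && s[z[k]] == s[k + z[k]]) {
            ++z[k];
        }
        if (k + z[k] > rt) {
            lt = k;
            rt = k + z[k];
        }
    }
    return z;
}

int main() {
    // Applying the Z-function to the reversed window turns prefix matches
    // into the suffix "repeat counts" that Step 2 computes:
    const std::string toks = "abccbcyabc";
    const std::string rev(toks.rbegin(), toks.rend());
    const std::vector<int> z = z_function(rev);

    const int n = (int) toks.size();
    for (int i = 0; i < n; ++i) {
        const int repeat = (i == n - 1) ? 0 : z[n - 1 - i];
        printf("%d ", repeat); // prints: 0 0 3 1 0 2 0 0 0 0
    }
    printf("\n");
}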
@@ -1823,19 +1877,37 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         }
     }

-    // Step 3: Find maximum repeat length for each token
+    // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
+    // that would be generated by emitting each new token that would extend a sequence.
+    //
+    // Following the same example as above:
+    //   Last N tokens: a b c c b c y a b c
+    //   Repeat counts: 0 0 3 1 0 2 0 0 0 0
+    //
+    // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
+    //   c: 3 -> 4 (from `a b c` to `a b c c`)
+    //   b: 1 -> 2 (from `c` to `c b`)
+    //   y: 2 -> 3 (from `b c` to `b c y`)
+
     for (int i = 0; i < last_n_repeat - 1; ++i) {
         int repeat_len = ctx->dry_repeat_count[i];
         if (repeat_len >= ctx->dry_allowed_length) {
+            // This token ends a repeat, so the next token would continue one.
+            // By convention, the value of `repeat_len` only includes the tokens currently
+            // in the context, not the new token that would be added.
             llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
+            // Track the maximum sequence ending in this token.
             const auto& it = ctx->dry_max_token_repeat.find(token);
             if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
                 ctx->dry_max_token_repeat[token] = repeat_len;
             }
         }
     }

-    // Step 4: Apply penalties
+    // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
+
+    // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
+    // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`.
     const float FLOAT_MAX_LOG = 88.7228391f;
     int max_exponent = 0;
     if (ctx->dry_base > 1.000001f) {