@@ -1678,6 +1678,7 @@ struct llama_sampler_dry {
     ring_buffer<llama_token> last_tokens;
 };

+// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void GetOverlappingTokenSequences(const struct llama_model * model, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
     const int n_vocab = llama_n_vocab(model);
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
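The multimap populated here maps a "head" token to the "tail" tokens that complete a restart sequence (the Step 1 comment further down describes this layout). As a minimal standalone sketch of that shape and of how it is queried with `equal_range`, with plain `int` token ids invented for illustration (real ids come from the model's tokenizer):

```cpp
#include <cstdio>
#include <unordered_map>
#include <vector>

int main() {
    // Same shape as the sampler's restart-sequence collection, with plain ints.
    std::unordered_multimap<int, std::vector<int>> token_sequences;
    token_sequences.emplace(10, std::vector<int>{});       // single-token restart sequence: empty tail
    token_sequences.emplace(10, std::vector<int>{42, 7});  // longer sequence sharing the same head
    // Enumerate every restart sequence beginning with head token 10:
    auto its = token_sequences.equal_range(10);
    for (auto it = its.first; it != its.second; ++it) {
        printf("head=10 tail_len=%zu\n", it->second.size());
    }
    return 0;
}
```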
@@ -1735,6 +1736,7 @@ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token to
     ctx->last_tokens.push_back(token);
 }

+// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dry *) smpl->ctx;

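The code below repeatedly indexes `last_tokens` with `rat(i)`, a reverse accessor on the ring buffer: `rat(0)` is the most recently accepted token, `rat(1)` the one before it, and so on. A vector-based stand-in, purely for illustration:

```cpp
#include <cassert>
#include <vector>

// Illustrative stand-in for ring_buffer::rat(): index i counts backwards
// from the newest element.
template <typename T>
const T & rat(const std::vector<T> & buf, size_t i) {
    assert(i < buf.size());
    return buf[buf.size() - 1 - i];
}

int main() {
    const std::vector<int> last_tokens = {11, 22, 33};
    assert(rat(last_tokens, 0) == 33); // newest token
    assert(rat(last_tokens, 2) == 11); // two tokens back
    return 0;
}
```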
@@ -1752,7 +1754,28 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
     ctx->dry_repeat_count.assign(last_n_repeat, 0);
     ctx->dry_max_token_repeat.clear();

-    // Step 1: Look for restart sequences
+    // Step 1: Look for restart sequences to limit the maximum repetition length.
+    // Work backwards through the context looking for any token that begins a restart sequence.
+    //
+    // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
+    // sequences that together comprise a restart sequence. This allows us to quickly check
+    // whether each token is the head of a complete sequence. Most restart sequences are actually
+    // a single token, and for these the "tail" is an empty vector.
+    //
+    // If the token is a "head", test all restart sequences that begin with this token
+    // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
+    // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
+    // longest matching sequence (if any) is used to limit the maximum repetition length.
+    //
+    // Note that in the case of a short sequence contained in a longer one, this might fail to
+    // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
+    // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
+    // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
+    //
+    // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
+    // have already clamped the maximum tail sequence length when generating `restart_sequences`.
+    // With clamping, this scan is O(N) in the context length.
+
     int rep_limit = last_n_repeat;
     for (int i = 0; i < last_n_repeat; ++i) {
         llama_token token = ctx->last_tokens.rat(i);
@@ -1762,10 +1785,15 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
         }
         int longest_match = -1;
         for (auto it = its.first; it != its.second; ++it) {
+            // Note that (*it) does not contain the head character, so seq_len will be
+            // the restart sequence length minus 1.
+            // In the common case of a single-token restart sequence, (*it) will be empty
+            // and we will trivially match.
             int seq_len = (int)it->second.size();
             if (seq_len > longest_match && seq_len <= (int)i) {
                 bool match = true;
                 for (int offset = 0; offset < seq_len; ++offset) {
+                    // The -1 when indexing `last_tokens` is because we already matched the head.
                     if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
                         match = false;
                         break;
@@ -1777,6 +1805,8 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
             }
         }
         if (longest_match >= 0) {
+            // We found a restart sequence starting `i` tokens from the end and continuing for
+            // `longest_match` tokens.
             rep_limit = i - longest_match;
             break;
         }
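Read on its own, Step 1 amounts to the following self-contained sketch. A plain vector (newest token last) stands in for the ring buffer, and the function name `compute_rep_limit` is invented for illustration:

```cpp
#include <cstdio>
#include <unordered_map>
#include <vector>

// Sketch of the Step 1 scan. rat(i) mimics ring_buffer::rat: i tokens back
// from the end of the context.
static int compute_rep_limit(const std::vector<int> & tokens,
                             const std::unordered_multimap<int, std::vector<int>> & restart_sequences) {
    const int n = (int) tokens.size();
    auto rat = [&](int i) { return tokens[n - 1 - i]; };
    for (int i = 0; i < n; ++i) {
        auto its = restart_sequences.equal_range(rat(i));
        int longest_match = -1;
        for (auto it = its.first; it != its.second; ++it) {
            // The tail must fit between position i and the end of the context.
            int seq_len = (int) it->second.size();
            if (seq_len > longest_match && seq_len <= i) {
                bool match = true;
                for (int offset = 0; offset < seq_len; ++offset) {
                    if (it->second[offset] != rat(i - offset - 1)) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    longest_match = seq_len;
                }
            }
        }
        if (longest_match >= 0) {
            // Only the tokens after the restart sequence can repeat.
            return i - longest_match;
        }
    }
    return n;
}

int main() {
    // Restart sequence {7, 8}: head 7, tail {8}.
    std::unordered_multimap<int, std::vector<int>> restarts;
    restarts.emplace(7, std::vector<int>{8});
    std::vector<int> tokens = {1, 2, 5, 6, 7, 8, 9};
    // Head 7 is found at i = 2 (two back from the end) and tail {8} matches,
    // so only the single token after the sequence (the 9) may repeat.
    printf("rep_limit = %d\n", compute_rep_limit(tokens, restarts)); // prints 1
    return 0;
}
```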
@@ -1785,13 +1815,35 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
         return;
     }

-    // Step 2: Z-algorithm implementation
+    // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
+    // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
+    // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
+    //
+    // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
+    //   https://ivanyu.me/blog/2014/10/15/z-algorithm/
+    //
+    // The code below is adapted from the public domain implementation by the same author here:
+    //   https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
+    //
+    // Example:
+    //   Last N tokens: a b c c b c y a b c
+    //   Repeat counts: 0 0 3 1 0 2 0 0 0 0
+    //                      ^
+    //   This `3` means that the last three tokens of the context (a b c) also appear here.
+    //
+    // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
+    // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
+    // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
+    // ensure that the inner while loops only examine each token in the context once as the outer
+    // for loop iterates over the context.
+
     {
         const int last = last_n_repeat - 1;
         int rt = 0, lt = 0;

         for (int k = 1; k < last_n_repeat; ++k) {
             if (k > rt) {
+                // If k is outside the current Z-box, do naive computation.
                 int n = 0;
                 while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
                     ++n;
@@ -1802,7 +1854,9 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
                     rt = k+n-1;
                 }
             } else {
-                int p = k - lt;
+                // If k is inside the current Z-box, consider two cases.
+
+                int p = k - lt; // Pair index.
                 int right_part_len = rt - k + 1;

                 if (ctx->dry_repeat_count[last - p] < right_part_len) {
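The reversed orientation can be checked against the comment's example using the ordinary forward Z-algorithm: compute Z-values on the reversed sequence and flip the result. A standalone sketch follows, with chars standing in for token ids; this is the textbook variant, not the sampler's in-place version with `rep_limit` clamping:

```cpp
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

// Standard forward Z-algorithm: z[k] = length of the longest substring of s
// starting at k that matches a prefix of s. O(n) thanks to the [lt, rt) Z-box.
static std::vector<int> z_function(const std::string & s) {
    const int n = (int) s.size();
    std::vector<int> z(n, 0);
    int lt = 0, rt = 0;
    for (int k = 1; k < n; ++k) {
        if (k < rt) {
            // Inside the Z-box: reuse the value of the paired earlier position.
            z[k] = std::min(rt - k, z[k - lt]);
        }
        while (k + z[k] < n && s[z[k]] == s[k + z[k]]) {
            ++z[k]; // extend the match naively
        }
        if (k + z[k] > rt) {
            lt = k;
            rt = k + z[k];
        }
    }
    return z;
}

int main() {
    // The comment's example: last N tokens "a b c c b c y a b c".
    // Z-values of the reversed sequence, flipped back into context order,
    // give the repeat counts from the comment: 0 0 3 1 0 2 0 0 0 0.
    std::string tokens = "abccbcyabc";
    std::string rev(tokens.rbegin(), tokens.rend());
    std::vector<int> z = z_function(rev);
    std::reverse(z.begin(), z.end());
    for (int v : z) printf("%d ", v);
    printf("\n");
    return 0;
}
```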
@@ -1823,19 +1877,37 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
         }
     }

-    // Step 3: Find maximum repeat length for each token
+    // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
+    // that would be generated by emitting each new token that would extend a sequence.
+    //
+    // Following the same example as above:
+    //   Last N tokens: a b c c b c y a b c
+    //   Repeat counts: 0 0 3 1 0 2 0 0 0 0
+    //
+    // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
+    //   c: 3 -> 4 (from `a b c` to `a b c c`)
+    //   b: 1 -> 2 (from `c` to `c b`)
+    //   y: 2 -> 3 (from `b c` to `b c y`)
+
     for (int i = 0; i < last_n_repeat - 1; ++i) {
         int repeat_len = ctx->dry_repeat_count[i];
         if (repeat_len >= ctx->dry_allowed_length) {
+            // This token ends a repeat, so the next token would continue one.
+            // By convention, the value of `repeat_len` only includes the tokens currently
+            // in the context, not the new token that would be added.
             llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
+            // Track the maximum sequence ending in this token.
             const auto & it = ctx->dry_max_token_repeat.find(token);
             if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
                 ctx->dry_max_token_repeat[token] = repeat_len;
             }
         }
     }

-    // Step 4: Apply penalties
+    // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
+
+    // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
+    // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`.
     const float FLOAT_MAX_LOG = 88.7228391f;
     int max_exponent = 0;
     if (ctx->dry_base > 1.000001f) {
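To see what the clamp protects, here is a hedged sketch of the penalty as described in the original DRY proposal, `penalty = multiplier * base^(repeat_len - allowed_length)`, with invented parameter values; without the clamp, the exponent in the last case would overflow a 32-bit float:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // Invented sampler parameters for illustration.
    const float dry_base = 1.75f;
    const float dry_multiplier = 0.8f;
    const int dry_allowed_length = 2;

    // Same guard as above: exponents past log(FLT_MAX) / log(base)
    // would overflow powf, so clamp them.
    const float FLOAT_MAX_LOG = 88.7228391f;
    int max_exponent = 0;
    if (dry_base > 1.000001f) {
        max_exponent = (int) (FLOAT_MAX_LOG / std::log(dry_base));
    }

    const int repeat_lens[] = {2, 10, 200};
    for (int repeat_len : repeat_lens) {
        int exponent = repeat_len - dry_allowed_length;
        if (max_exponent > 0 && exponent > max_exponent) {
            exponent = max_exponent; // without this, the 200-token case yields inf
        }
        const float penalty = dry_multiplier * std::pow(dry_base, (float) exponent);
        printf("repeat_len=%3d -> subtract %g from the token's logit\n", repeat_len, penalty);
    }
    return 0;
}
```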