@@ -11,6 +11,7 @@
 #include <time.h>
 #include <mutex>
 #include <unordered_map>
+#include <unordered_set>
 #include "model_adapter.h"
 #include "otherarch.h"
 #include "llama.h"
@@ -1188,34 +1189,22 @@ void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, float rep_pen_s
     const int64_t t_start_sample_us = ggml_time_us();

     // Create a frequency map to count occurrences of each token in last_tokens
-    std::unordered_map<llama_token, int> token_count_near;
-    std::unordered_map<llama_token, int> token_count_far;
-    for (size_t i = 0; i < last_n_repeat; ++i) {
-        if ((i*2) >= last_n_repeat)
-        {
-            token_count_near[last_tokens[i]]++;
-        }
-        else
-        {
-            token_count_far[last_tokens[i]]++;
-        }
-    }
-
+    std::unordered_set<llama_token> tokens_near(last_tokens + last_n_repeat / 2, last_tokens + last_n_repeat);
+    std::unordered_set<llama_token> tokens_far(last_tokens, last_tokens + last_n_repeat / 2);
+
     float rep_pen_reduced = rep_pen;
     if (rep_pen_reduced>1.0f)
     {
         rep_pen_reduced = 1.0f + ((rep_pen-1.0f)*rep_pen_slope);
     }
     for (size_t i = 0; i < candidates->size; ++i) {
-        const auto token_in_near = token_count_near.find(candidates->data[i].id);
-        const auto token_in_far = token_count_far.find(candidates->data[i].id);
-        bool in_near = (token_in_near != token_count_near.end());
-        bool in_far = (token_in_far != token_count_far.end());
-        if (!in_near && !in_far) {
+        const bool token_in_near = tokens_near.find(candidates->data[i].id) != tokens_near.end();
+        const bool token_in_far = tokens_far.find(candidates->data[i].id) != tokens_far.end();
+        if (!token_in_near && !token_in_far) {
             continue;
         }

-        float penalty = (in_near?rep_pen:rep_pen_reduced);
+        float penalty = (token_in_near?rep_pen:rep_pen_reduced);

         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
@@ -1229,7 +1218,6 @@ void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, float rep_pen_s
     }

     candidates->sorted = false;
-
 }

 void sample_top_p(llama_token_data_array * cur_p, float p, size_t min_keep) {
0 commit comments