Commit 06fe528

sample: maintain token count in penalty sampler context
1 parent a89f75e commit 06fe528

File tree: 1 file changed (+13 -12 lines)

src/llama-sampling.cpp

Lines changed: 13 additions & 12 deletions
@@ -1386,6 +1386,9 @@ struct llama_sampler_penalties {
     const bool ignore_eos;
 
     ring_buffer<llama_token> prev;
+
+    // Frequency map to count occurrences of each token in last_tokens
+    std::unordered_map<llama_token, size_t> token_count;
 };
 
 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@@ -1398,7 +1401,13 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
         return;
     }
 
+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        if (--ctx->token_count[ctx->prev.front()] == 0) {
+            ctx->token_count.erase(ctx->prev.front());
+        }
+    }
     ctx->prev.push_back(token);
+    ctx->token_count[token]++;
 }
 
 static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1450,23 +1459,14 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         }
     }
 
-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
-
-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
-    }
-
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
             continue;
         }
 
-        const int count = token_iter->second;
+        const size_t count = token_iter->second;
 
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
@@ -1561,6 +1561,7 @@ struct llama_sampler * llama_sampler_init_penalties(
            /* .penalize_nl = */ penalize_nl,
            /* .ignore_eos = */ ignore_eos,
            /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
+           /* .token_count = */ std::unordered_map<llama_token, size_t>(),
        },
    };
 }
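
The change resolves the TODO in the old code: instead of rebuilding a frequency map over the last penalty_last_n tokens on every apply() call, the map now lives in the sampler context and is updated incrementally as tokens are accepted. Below is a minimal standalone sketch of that pattern, not llama.cpp's implementation: llama_token is assumed to be int32_t, penalty_window and its members are hypothetical names, and std::deque stands in for the project's ring_buffer, so eviction is an explicit pop_front here rather than an implicit overwrite.

    // Sketch: sliding window of recent tokens with an incrementally
    // maintained frequency map, mirroring the pattern in this commit.
    #include <cstdint>
    #include <cstdio>
    #include <deque>
    #include <unordered_map>

    using llama_token = int32_t; // assumption: stand-in for llama.cpp's token type

    struct penalty_window {
        size_t capacity; // plays the role of penalty_last_n
        std::deque<llama_token> prev; // stand-in for llama.cpp's ring_buffer
        std::unordered_map<llama_token, size_t> token_count;

        explicit penalty_window(size_t cap) : capacity(cap) {}

        // O(1) amortized per token: decrement the count of the token about
        // to fall out of the window, then record the new token.
        void accept(llama_token token) {
            if (prev.size() >= capacity) {
                const llama_token oldest = prev.front();
                if (--token_count[oldest] == 0) {
                    token_count.erase(oldest);
                }
                prev.pop_front(); // a ring_buffer would overwrite this slot implicitly
            }
            prev.push_back(token);
            token_count[token]++;
        }

        // O(1) lookup at penalty time, replacing the per-apply rebuild of the map.
        size_t count(llama_token token) const {
            const auto it = token_count.find(token);
            return it == token_count.end() ? 0 : it->second;
        }
    };

    int main() {
        penalty_window w(3);
        for (llama_token t : {7, 7, 9, 7}) {
            w.accept(t); // window ends as {7, 9, 7}
        }
        std::printf("count(7) = %zu, count(9) = %zu\n", w.count(7), w.count(9));
        return 0;
    }

With this layout, accept() does constant work per generated token and the penalty pass pays one hash lookup per candidate, rather than an O(penalty_last_n) rescan of the window on every call.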
