src/llama-sampling.cpp (27 changes: 15 additions & 12 deletions)

@@ -1386,6 +1386,9 @@ struct llama_sampler_penalties {
     const bool ignore_eos;

     ring_buffer<llama_token> prev;
+
+    // Frequency map to count occurrences of each token in last_tokens
+    std::unordered_map<llama_token, size_t> token_count;
 };

 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {

@@ -1398,7 +1401,14 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
         return;
     }

+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        assert(ctx->token_count.at(ctx->prev.front()) > 0);
+        if (--ctx->token_count[ctx->prev.front()] == 0) {
+            ctx->token_count.erase(ctx->prev.front());
+        }
+    }
     ctx->prev.push_back(token);
+    ctx->token_count[token]++;
 }

 static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
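Review note: the accept path above keeps `token_count` in sync with the sliding window of the last `penalty_last_n` tokens. A self-contained sketch of the same bookkeeping, for reference only (assumptions: `std::deque` stands in for `ring_buffer<llama_token>`, plain `int` for `llama_token`; none of this is PR code):

```cpp
#include <cassert>
#include <cstdio>
#include <deque>
#include <unordered_map>

struct window_counter {
    size_t capacity;                          // plays the role of penalty_last_n
    std::deque<int> prev;                     // plays the role of ring_buffer<llama_token>
    std::unordered_map<int, size_t> token_count;

    void accept(int token) {
        if (prev.size() >= capacity) {        // evict the oldest token's count first
            assert(token_count.at(prev.front()) > 0);
            if (--token_count[prev.front()] == 0) {
                token_count.erase(prev.front());
            }
            prev.pop_front();                 // deque needs an explicit pop; ring_buffer overwrites
        }
        prev.push_back(token);
        token_count[token]++;
    }
};

int main() {
    window_counter w{3, {}, {}};
    for (int t : {7, 7, 9, 7}) {              // window of 3: ends up holding {7, 9, 7}
        w.accept(t);
    }
    printf("count(7) = %zu\n", w.token_count.at(7));  // prints 2, not 3: the first 7 was evicted
}
```

The map only ever holds tokens currently inside the window, so lookups in the apply path stay O(1) and memory stays bounded by `penalty_last_n`.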
@@ -1450,23 +1460,14 @@ static void llama_sampler_penalties_apply(struct llama_tok
         }
     }

-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
-
-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
-    }
-
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
             continue;
         }

-        const int count = token_iter->second;
+        const size_t count = token_iter->second;

         // The academic publication that described this technique only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // A common fix for this problem is to multiply by the penalty instead of dividing.
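Review note on the unchanged comment above: a tiny numeric illustration (assumed value `penalty_repeat = 1.5f`; not PR code) of why dividing a negative logit makes a repeated token more likely, while multiplying penalizes it as intended:

```cpp
#include <cstdio>

int main() {
    const float penalty_repeat = 1.5f;
    const float logit = -2.0f;
    printf("divide:   %.2f\n", logit / penalty_repeat);  // -1.33: logit rises, token gets MORE likely (wrong)
    printf("multiply: %.2f\n", logit * penalty_repeat);  // -3.00: logit falls, token gets less likely (fix)
}
```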

@@ -1490,6 +1491,7 @@ static void llama_sampler_penalties_apply(struct llama_tok
 static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
     ctx->prev.clear();
+    ctx->token_count.clear();
 }

 static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {

@@ -1561,6 +1563,7 @@ struct llama_sampler * llama_sampler_init_penalties(
             /* .penalize_nl = */ penalize_nl,
             /* .ignore_eos  = */ ignore_eos,
             /* .prev        = */ ring_buffer<llama_token>(penalty_last_n),
+            /* .token_count = */ std::unordered_map<llama_token, size_t>(),
         },
     };
 }
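Review note: a throwaway harness (assumptions: `std::deque` for the ring buffer, an arbitrary token stream; not PR code) checking the invariant this PR relies on, namely that the incrementally maintained map always equals the map the deleted per-apply loop would have rebuilt:

```cpp
#include <cassert>
#include <deque>
#include <unordered_map>

int main() {
    const size_t penalty_last_n = 4;
    std::deque<int> prev;
    std::unordered_map<int, size_t> incremental;

    const int stream[] = {1, 2, 2, 3, 1, 2, 1, 1, 3, 2};
    for (int token : stream) {
        // incremental update, mirroring llama_sampler_penalties_accept
        if (prev.size() >= penalty_last_n) {
            if (--incremental[prev.front()] == 0) {
                incremental.erase(prev.front());
            }
            prev.pop_front();
        }
        prev.push_back(token);
        incremental[token]++;

        // rebuild from scratch, as the old apply path did on every call
        std::unordered_map<int, size_t> rebuilt;
        for (int t : prev) {
            rebuilt[t]++;
        }
        assert(incremental == rebuilt);
    }
    return 0;
}
```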