1 change: 1 addition & 0 deletions common/common.h
@@ -117,6 +117,7 @@ struct gpt_sampler_params {
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float penalty_repeat_sigmoid_growth = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
1 change: 1 addition & 0 deletions common/sampling.cpp
@@ -168,6 +168,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
params.penalty_repeat,
params.penalty_freq,
params.penalty_present,
params.penalty_repeat_sigmoid_growth,
params.penalize_nl,
params.ignore_eos));

2 changes: 2 additions & 0 deletions examples/server/README.md
@@ -350,6 +350,8 @@ node index.js

`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.

`repeat_penalty_sigmoid_growth`: Shape `repeat_penalty` with a sigmoid curve across the `repeat_last_n` window, so that older tokens are penalized less than recent ones. A value of `1` gives a roughly linear ramp from 1 up to `repeat_penalty` across the window. Values greater than `1` sharpen the contrast between the first and second halves of the window; values between `0` and `1` make the penalty change more slowly around the middle of the window. Negative values apply the same curve in reverse, ramping from `repeat_penalty` down to 1. Default: `0.0`, which is disabled.

`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.

`mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`
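For example, a request enabling the new parameter might look like the following. This is a hypothetical sketch in the style of the completion examples elsewhere in that README; the endpoint is real, but the prompt and parameter values are illustrative only:

```shell
curl --request POST --url http://localhost:8080/completion \
    --header "Content-Type: application/json" \
    --data '{
        "prompt": "Once upon a time",
        "n_predict": 64,
        "repeat_penalty": 1.5,
        "repeat_last_n": 64,
        "repeat_penalty_sigmoid_growth": 1.0
    }'
```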
2 changes: 2 additions & 0 deletions examples/server/server.cpp
@@ -898,6 +898,7 @@ struct server_context {
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
slot.sparams.penalty_repeat_sigmoid_growth = json_value(data, "repeat_penalty_sigmoid_growth", default_sparams.penalty_repeat_sigmoid_growth);
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
@@ -1239,6 +1240,7 @@ struct server_context {
{"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present},
{"frequency_penalty", slot.sparams.penalty_freq},
{"repeat_penalty_sigmoid_growth", slot.sparams.penalty_repeat_sigmoid_growth},
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},
1 change: 1 addition & 0 deletions include/llama.h
@@ -1124,6 +1124,7 @@ extern "C" {
float penalty_repeat, // 1.0 = disabled
float penalty_freq, // 0.0 = disabled
float penalty_present, // 0.0 = disabled
float penalty_repeat_sigmoid_growth, // 0.0 = disabled
bool penalize_nl, // consider newlines as a repeatable token
bool ignore_eos); // ignore the end-of-sequence token

123 changes: 120 additions & 3 deletions src/llama-sampling.cpp
@@ -1381,6 +1381,7 @@ struct llama_sampler_penalties {
const float penalty_repeat;
const float penalty_freq;
const float penalty_present;
const float penalty_repeat_sigmoid_growth;

const bool penalize_nl;
const bool ignore_eos;
@@ -1450,6 +1451,115 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
}
}

struct sigmoid {
protected:
bool enabled;
float growth;
bool use_mirrored;
const ring_buffer<llama_token> & last_tokens;
size_t last_tokens_size;
size_t penalty_last_n;
float token_x;
float y_min = 0;
float y_diff = 0;

inline float calc_sigmoid(float x) {
float y = 1 / (1 + exp((-x + 0.5) * growth));
return y;
}

inline float calc_sigmoid_inv_growth(float x) {
float y = 1 / (1 + exp((-x + 0.5) / growth));
return y;
}

// sigmoid mirrored by y=x
inline float calc_mirrored_sigmoid(float x) {
if ((x == 0 && growth > 0) || (x >= 1 && growth < 0)) {
return 0;
}
if ((x == 0 && growth < 0) || (x >= 1 && growth > 0)) {
return 1;
}
// the actual formula: y = 0.5 - log((1 - x) / x) / growth
// but we invert the growth to transform the initial (0;1) range to the (1;+inf) range
float inv_growth = 1 / growth;
float y = 0.5 - log((1 - x) / x) / inv_growth;
return y;
}

inline float calc_norm_coeff(float x) {
if (use_mirrored) {
float norm_x = (x + y_min) * y_diff; // normalize x within a range of the non-mirrored sigmoid's y
float y = calc_mirrored_sigmoid(norm_x);
return y;
}

float y = calc_sigmoid(x);
float norm_y = (y - y_min) / y_diff;
return norm_y;
}

static inline float apply_norm_coeff(float coeff, float penalty) {
float initial_diff = penalty - 1;
float result_diff = initial_diff * coeff;
return 1 + result_diff;
}

public:
explicit sigmoid(
float growth,
const ring_buffer<llama_token> & last_tokens,
size_t penalty_last_n
) :
enabled(growth != 0),
growth(growth),
use_mirrored(fabsf(growth) < 1),
last_tokens(last_tokens),
last_tokens_size(std::min(penalty_last_n, last_tokens.size())),
penalty_last_n(penalty_last_n),
token_x(1 / (float)penalty_last_n) {
if (!enabled) {
return;
}
float y1;
float y2;
if (use_mirrored) {
y1 = calc_sigmoid_inv_growth(0);
y2 = calc_sigmoid_inv_growth(1);
} else {
y1 = calc_sigmoid(0);
y2 = calc_sigmoid(1);
}
y_min = std::min(y1, y2);
float y_max = std::max(y1, y2);
y_diff = y_max - y_min;
}

inline float apply(float penalty, llama_token token) {
if (!enabled) {
return penalty;
}
// the position (from the end) within the penalty tokens array
size_t token_rindex = 0;
while (token_rindex < last_tokens_size) {
if (last_tokens.rat(token_rindex) == token) {
break; // must always break at some point, otherwise it's UB
}
token_rindex++;
}
// the position within the penalty range,
// it's 1-indexed, so the last token in the range will correspond to x=1
size_t token_pos = penalty_last_n - token_rindex;
float x = token_x * token_pos;
float coeff = calc_norm_coeff(x);
float resulting_penalty = apply_norm_coeff(coeff, penalty);
return resulting_penalty;
}
};

sigmoid penalty_repeat_sigmoid(ctx->penalty_repeat_sigmoid_growth, ctx->prev, ctx->penalty_last_n);

// Create a frequency map to count occurrences of each token in last_tokens
// TODO: optimize this by maintaining the token count in the sampler context
using llama_token_cnt = std::unordered_map<llama_token, int>;
@@ -1461,7 +1571,8 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok

// Apply frequency and presence penalties to the cur_p
for (size_t i = 0; i < cur_p->size; ++i) {
const auto token_iter = token_count.find(cur_p->data[i].id);
const auto token = cur_p->data[i].id;
const auto token_iter = token_count.find(token);
if (token_iter == token_count.end()) {
continue;
}
@@ -1470,11 +1581,14 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok

// The academic publication that described this technique only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// A common fix for that problem is to multiply by the penalty instead of dividing.
float applied_penalty_repeat;
if (cur_p->data[i].logit <= 0) {
cur_p->data[i].logit *= ctx->penalty_repeat;
applied_penalty_repeat = ctx->penalty_repeat;
} else {
cur_p->data[i].logit /= ctx->penalty_repeat;
applied_penalty_repeat = 1 / ctx->penalty_repeat;
}
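// Hypothetical worked example (not part of the diff): with penalty_repeat = 1.5
// and a positive logit, the base multiplier is 1/1.5 ~= 0.667; a mid-window
// sigmoid coefficient of 0.5 relaxes it to 1 + (0.667 - 1) * 0.5 ~= 0.833,
// so the further a token sits from the full-penalty end of the window,
// the more of its logit it keeps.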
applied_penalty_repeat = penalty_repeat_sigmoid.apply(applied_penalty_repeat, token);
cur_p->data[i].logit *= applied_penalty_repeat;

cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
}
@@ -1502,6 +1616,7 @@ static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_s
ctx->penalty_repeat,
ctx->penalty_freq,
ctx->penalty_present,
ctx->penalty_repeat_sigmoid_growth,
ctx->penalize_nl,
ctx->ignore_eos);

@@ -1536,6 +1651,7 @@ struct llama_sampler * llama_sampler_init_penalties(
float penalty_repeat,
float penalty_freq,
float penalty_present,
float penalty_repeat_sigmoid_growth,
bool penalize_nl,
bool ignore_eos) {
if (linefeed_id == LLAMA_TOKEN_NULL) {
@@ -1558,6 +1674,7 @@ struct llama_sampler * llama_sampler_init_penalties(
/* .penalty_repeat = */ penalty_repeat,
/* .penalty_freq = */ penalty_freq,
/* .penalty_present = */ penalty_present,
/* .penalty_repeat_sigmoid_growth = */ penalty_repeat_sigmoid_growth,
/* .penalize_nl = */ penalize_nl,
/* .ignore_eos = */ ignore_eos,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
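For intuition about the curve the new sampler applies, here is a minimal standalone sketch (not part of the diff) that reproduces the non-mirrored path (`|growth| >= 1`) of `calc_sigmoid` and `calc_norm_coeff` above; the window size of 8 and the helper names are illustrative assumptions:

```cpp
// Standalone sketch: normalized sigmoid coefficient over a penalty window,
// mirroring the non-mirrored path (|growth| >= 1) of the PR's sigmoid struct.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <initializer_list>

// Coefficient at position pos (1-indexed; pos == n is the most recent token).
static float norm_coeff(float growth, int pos, int n) {
    auto sigmoid = [&](float x) { return 1.0f / (1.0f + std::exp((-x + 0.5f) * growth)); };
    const float y0     = sigmoid(0.0f);
    const float y1     = sigmoid(1.0f);
    const float y_min  = std::min(y0, y1);
    const float y_diff = std::max(y0, y1) - y_min;
    const float x      = (float) pos / (float) n;
    return (sigmoid(x) - y_min) / y_diff; // normalized to [0, 1]
}

int main() {
    // growth = 1 ramps almost linearly from ~0 to 1 across the window;
    // growth = 10 stays near 0 in the first half and near 1 in the second.
    for (float growth : {1.0f, 10.0f}) {
        std::printf("growth = %2.0f:", growth);
        for (int pos = 1; pos <= 8; ++pos) {
            std::printf(" %.2f", norm_coeff(growth, pos, 8));
        }
        std::printf("\n");
    }
    return 0;
}
```

The resulting coefficient scales the distance of the repeat penalty from 1, so a coefficient of 0 leaves a token unpenalized and a coefficient of 1 applies the full `repeat_penalty`.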
20 changes: 12 additions & 8 deletions tests/test-sampling.cpp
@@ -134,7 +134,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo

static void test_penalties(
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence, float penalty_repeat_sigmoid_growth
) {
GGML_ASSERT(probs.size() == expected_probs.size());

@@ -149,7 +149,7 @@ static void test_penalties(

llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };

auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, penalty_repeat_sigmoid_growth, false, false);

for (size_t i = 0; i < last_tokens.size(); i++) {
llama_sampler_accept(sampler, last_tokens[i]);
@@ -316,13 +316,17 @@ int main(void) {
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);

test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f);

test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 0.0f, 0.0f, 10.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.263353f, 0.263353f, 0.201890f, 0.153630f, 0.117775f}, 1.5f, 0.0f, 0.0f, 1.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.290188f, 0.246533f, 0.182452f, 0.140414f, 0.140414f}, 0.5f, 0.0f, 0.0f, -0.5f);

test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f, 0.0f);

test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
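A note on the three new expectations: in the first, `repeat_penalty` is `1.0`, so the sigmoid coefficient scales a zero penalty delta (`1 + (1.0 - 1) * coeff`) and the distribution stays uniform regardless of the growth value; the second checks the roughly linear ramp (growth `1.0`) against a penalty of `1.5`; and the third exercises the mirrored path (`|growth| < 1`) with a negative growth of `-0.5` and a penalty of `0.5`.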