Skip to content

Commit 2740af3

Browse files
add top n sigma sampler from llama.cpp (#1384)
* Add N Sigma Sampler * update nsigma sampler chain * xtc position fix * remove stray newline --------- Co-authored-by: CasualAutopsy <[email protected]>
1 parent 5f74ee3 commit 2740af3

File tree

4 files changed

+61
-4
lines changed

4 files changed

+61
-4
lines changed

expose.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ struct generation_inputs
8282
const float min_p = 0.0f;
8383
const float typical_p = 0;
8484
const float tfs = 0;
85+
const float nsigma = -1.0f;
8586
const float rep_pen = 0;
8687
const int rep_pen_range = 0;
8788
const float rep_pen_slope = 1.0f;

gpttype_adapter.cpp

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,6 +1428,35 @@ void sampler_typical(llama_token_data_array * cur_p, float p, size_t min_keep) {
14281428
cur_p->sorted = false;
14291429
}
14301430

1431+
// Top-n-sigma filter.
//
// Computes the maximum, the mean and the population standard deviation
// (divides by size, not size - 1) of the candidate logits, then penalises every
// candidate whose logit is more than `nsigma` standard deviations below the
// maximum. Penalised candidates are not removed from the array; 999 is
// subtracted from their logit. sample_softmax() is then called on the array
// (presumably to recompute probabilities so that penalised tokens end up with
// negligible mass - confirm against sample_softmax).
//
// cur_p  : candidate array, modified in place.
// nsigma : width of the kept band in standard deviations. The visible caller
//          only invokes this when nsigma > 0.
void sample_top_n_sigma(llama_token_data_array * cur_p, float nsigma) {
    // An empty candidate list has no maximum and would divide by zero below;
    // there is nothing to filter, so leave it untouched.
    if (cur_p->size == 0) {
        return;
    }

    // find max logit and calculate mean
    float nsigmax = cur_p->data[0].logit;
    float logits_sum = 0;
    for (size_t i = 0; i < cur_p->size; ++i) {
        const float logit = cur_p->data[i].logit;
        if (logit > nsigmax) {
            nsigmax = logit;
        }
        logits_sum += logit;
    }
    const float nsigmean = logits_sum / cur_p->size;

    // calculate standard deviation
    float nsigacc = 0;
    for (size_t i = 0; i < cur_p->size; ++i) {
        // plain multiply instead of pow(x, 2): same square, no double-precision libm call
        const float diff = cur_p->data[i].logit - nsigmean;
        nsigacc += diff * diff;
    }
    const float nsigstd = sqrt(nsigacc / cur_p->size);

    // apply mask: the cut-off is loop-invariant, so compute it once
    const float threshold = nsigmax - (nsigma * nsigstd);
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].logit < threshold) {
            // soft mask: push the logit far down instead of erasing the entry
            cur_p->data[i].logit -= 999.0f;
        }
    }
    sample_softmax(cur_p);
}
1459+
14311460
void sample_entropy(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val, float smoothing_factor) {
14321461
// no need to do anything if there is only one (or zero) candidates
14331462
if (cur_p->size <= 1) {
@@ -1561,7 +1590,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
15611590

15621591
}
15631592

1564-
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
1593+
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float nsigma, float temp, std::mt19937 & rng,
15651594
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability,
15661595
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
15671596
{
@@ -1584,8 +1613,10 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
15841613
sample_grammar(file_format, n_vocab, &candidates_p, grammar);
15851614
}
15861615

1587-
//dry always first as logits cannot be resorted
1588-
sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
1616+
if (nsigma <= 0.0f){
1617+
//dry always first as logits cannot be resorted
1618+
sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
1619+
}
15891620

15901621
//prefilter to top 3k tokens for improved speed
15911622
sample_top_k(&candidates_p, 3000);
@@ -1605,6 +1636,25 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
16051636
id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
16061637
}
16071638
}
1639+
else if (nsigma > 0.0f)
1640+
{
1641+
sample_top_k(&candidates_p, top_k);
1642+
if (dynatemp_range > 0) {
1643+
float dynatemp_min = temp - dynatemp_range;
1644+
float dynatemp_max = temp + dynatemp_range;
1645+
//do not allow negative values
1646+
dynatemp_min = dynatemp_min < 0 ? 0 : dynatemp_min;
1647+
dynatemp_max = dynatemp_max < 0 ? 0 : dynatemp_max;
1648+
dynatemp_exponent = dynatemp_exponent < 0 ? 0 : dynatemp_exponent;
1649+
sample_entropy(&candidates_p, dynatemp_min, dynatemp_max, dynatemp_exponent, smoothing_factor);
1650+
} else {
1651+
sample_temperature(&candidates_p, temp, smoothing_factor);
1652+
}
1653+
sample_top_n_sigma(&candidates_p, nsigma);
1654+
1655+
sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng);
1656+
id = sample_token(&candidates_p, rng);
1657+
}
16081658
else
16091659
{
16101660
for (int i = 0; i < sampler_order.size(); i++)
@@ -2999,6 +3049,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
29993049
kcpp_data->min_p = inputs.min_p;
30003050
kcpp_data->typical_p = inputs.typical_p;
30013051
kcpp_data->tfs_z = inputs.tfs;
3052+
kcpp_data->nsigma = inputs.nsigma;
30023053
kcpp_data->temp = inputs.temperature;
30033054
kcpp_data->repeat_last_n = inputs.rep_pen_range;
30043055
kcpp_data->rep_pen_slope = inputs.rep_pen_slope;
@@ -3529,6 +3580,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
35293580
const float presence_penalty = kcpp_data->presence_penalty;
35303581
const float typical_p = kcpp_data->typical_p;
35313582
const float tfs_z = kcpp_data->tfs_z;
3583+
const float nsigma = kcpp_data->nsigma;
35323584
const float dynatemp_range = kcpp_data->dynatemp_range;
35333585
const float dynatemp_exponent = kcpp_data->dynatemp_exponent;
35343586
const float smoothing_factor = kcpp_data->smoothing_factor;
@@ -3624,7 +3676,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
36243676
}
36253677

36263678
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, kcpp_data->rep_pen_slope, presence_penalty,
3627-
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
3679+
top_k, top_a, top_p, min_p, typical_p, tfs_z, nsigma, temp, rng,
36283680
kcpp_data->mirostat, kcpp_data->mirostat_tau, kcpp_data->mirostat_eta,
36293681
kcpp_data->dry_multiplier, kcpp_data->dry_base,
36303682
kcpp_data->dry_allowed_length, kcpp_data->dry_penalty_last_n, kcpp_data->xtc_threshold, kcpp_data->xtc_probability,

koboldcpp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ class generation_inputs(ctypes.Structure):
194194
("min_p", ctypes.c_float),
195195
("typical_p", ctypes.c_float),
196196
("tfs", ctypes.c_float),
197+
("nsigma", ctypes.c_float),
197198
("rep_pen", ctypes.c_float),
198199
("rep_pen_range", ctypes.c_int),
199200
("rep_pen_slope", ctypes.c_float),
@@ -1116,6 +1117,7 @@ def generate(genparams, stream_flag=False):
11161117
min_p = float(genparams.get('min_p', 0.0))
11171118
typical_p = float(genparams.get('typical', 1.0))
11181119
tfs = float(genparams.get('tfs', 1.0))
1120+
nsigma = float(genparams.get('nsigma', -1.0))
11191121
rep_pen = float(genparams.get('rep_pen', 1.0))
11201122
rep_pen_range = int(genparams.get('rep_pen_range', 320))
11211123
rep_pen_slope = float(genparams.get('rep_pen_slope', 1.0))
@@ -1182,6 +1184,7 @@ def generate(genparams, stream_flag=False):
11821184
inputs.min_p = min_p
11831185
inputs.typical_p = typical_p
11841186
inputs.tfs = tfs
1187+
inputs.nsigma = nsigma
11851188
inputs.rep_pen = rep_pen
11861189
inputs.rep_pen_range = rep_pen_range
11871190
inputs.rep_pen_slope = rep_pen_slope

otherarch/otherarch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ struct kcpp_params {
2929
float top_p = 0.95f; // 1.0 = disabled
3030
float min_p = 0.0f; // 0.0 = disabled
3131
float tfs_z = 1.00f; // 1.0 = disabled
32+
float nsigma = -1.00f; // -1.0 - disabled
3233
float typical_p = 1.00f; // 1.0 = disabled
3334
float temp = 0.80f; // 1.0 = disabled
3435
float smoothing_factor = 0.00f; // 0.00 = disabled

0 commit comments

Comments
 (0)