@@ -1428,6 +1428,35 @@ void sampler_typical(llama_token_data_array * cur_p, float p, size_t min_keep) {
14281428 cur_p->sorted = false ;
14291429}
14301430
1431+ void sample_top_n_sigma (llama_token_data_array * cur_p, float nsigma) {
1432+
1433+ // find max logit and calculate mean
1434+ float nsigmax = cur_p->data [0 ].logit ;
1435+ float logits_sum = 0 ;
1436+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
1437+ if (cur_p->data [i].logit > nsigmax) {
1438+ nsigmax = cur_p->data [i].logit ;
1439+ }
1440+ logits_sum += cur_p->data [i].logit ;
1441+ }
1442+ float nsigmean = logits_sum / cur_p->size ;
1443+
1444+ // calculate standard deviation
1445+ float nsigacc = 0 ;
1446+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
1447+ nsigacc += pow (cur_p->data [i].logit - nsigmean, 2 );
1448+ }
1449+ float nsigstd = sqrt (nsigacc / cur_p->size );
1450+
1451+ // apply mask
1452+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
1453+ if (cur_p->data [i].logit < nsigmax - (nsigma * nsigstd)) {
1454+ cur_p->data [i].logit -= 999 .0f ;
1455+ }
1456+ }
1457+ sample_softmax (cur_p);
1458+ }
1459+
14311460void sample_entropy (llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val, float smoothing_factor) {
14321461 // no need to do anything if there is only one (or zero) candidates
14331462 if (cur_p->size <= 1 ) {
@@ -1561,7 +1590,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
15611590
15621591}
15631592
1564- int SampleLogits (const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
1593+ int SampleLogits (const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float nsigma, float temp, std::mt19937 & rng,
15651594int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability,
15661595const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
15671596{
@@ -1584,8 +1613,10 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
15841613 sample_grammar (file_format, n_vocab, &candidates_p, grammar);
15851614 }
15861615
1587- // dry always first as logits cannot be resorted
1588- sample_dry (n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
1616+ if (nsigma <= 0 .0f ){
1617+ // dry always first as logits cannot be resorted
1618+ sample_dry (n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
1619+ }
15891620
15901621 // prefilter to top 3k tokens for improved speed
15911622 sample_top_k (&candidates_p, 3000 );
@@ -1605,6 +1636,25 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
16051636 id = sample_token_mirostat_v2 (&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
16061637 }
16071638 }
1639+ else if (nsigma > 0 .0f )
1640+ {
1641+ sample_top_k (&candidates_p, top_k);
1642+ if (dynatemp_range > 0 ) {
1643+ float dynatemp_min = temp - dynatemp_range;
1644+ float dynatemp_max = temp + dynatemp_range;
1645+ // do not allow negative values
1646+ dynatemp_min = dynatemp_min < 0 ? 0 : dynatemp_min;
1647+ dynatemp_max = dynatemp_max < 0 ? 0 : dynatemp_max;
1648+ dynatemp_exponent = dynatemp_exponent < 0 ? 0 : dynatemp_exponent;
1649+ sample_entropy (&candidates_p, dynatemp_min, dynatemp_max, dynatemp_exponent, smoothing_factor);
1650+ } else {
1651+ sample_temperature (&candidates_p, temp, smoothing_factor);
1652+ }
1653+ sample_top_n_sigma (&candidates_p, nsigma);
1654+
1655+ sample_xtc (&candidates_p, xtc_threshold, xtc_probability, rng);
1656+ id = sample_token (&candidates_p, rng);
1657+ }
16081658 else
16091659 {
16101660 for (int i = 0 ; i < sampler_order.size (); i++)
@@ -2999,6 +3049,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
29993049 kcpp_data->min_p = inputs.min_p ;
30003050 kcpp_data->typical_p = inputs.typical_p ;
30013051 kcpp_data->tfs_z = inputs.tfs ;
3052+ kcpp_data->nsigma = inputs.nsigma ;
30023053 kcpp_data->temp = inputs.temperature ;
30033054 kcpp_data->repeat_last_n = inputs.rep_pen_range ;
30043055 kcpp_data->rep_pen_slope = inputs.rep_pen_slope ;
@@ -3529,6 +3580,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
35293580 const float presence_penalty = kcpp_data->presence_penalty ;
35303581 const float typical_p = kcpp_data->typical_p ;
35313582 const float tfs_z = kcpp_data->tfs_z ;
3583+ const float nsigma = kcpp_data->nsigma ;
35323584 const float dynatemp_range = kcpp_data->dynatemp_range ;
35333585 const float dynatemp_exponent = kcpp_data->dynatemp_exponent ;
35343586 const float smoothing_factor = kcpp_data->smoothing_factor ;
@@ -3624,7 +3676,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
36243676 }
36253677
36263678 id = SampleLogits (logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, kcpp_data->rep_pen_slope , presence_penalty,
3627- top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
3679+ top_k, top_a, top_p, min_p, typical_p, tfs_z, nsigma, temp, rng,
36283680 kcpp_data->mirostat , kcpp_data->mirostat_tau , kcpp_data->mirostat_eta ,
36293681 kcpp_data->dry_multiplier , kcpp_data->dry_base ,
36303682 kcpp_data->dry_allowed_length , kcpp_data->dry_penalty_last_n , kcpp_data->xtc_threshold , kcpp_data->xtc_probability ,
0 commit comments