@@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
 // penalties
 
 struct llama_sampler_penalties {
-    const int32_t     n_vocab;
-    const llama_token special_eos_id;
-    const llama_token linefeed_id;
-
     const int32_t penalty_last_n;
     const float   penalty_repeat;
     const float   penalty_freq;
     const float   penalty_present;
 
-    const bool penalize_nl;
-    const bool ignore_eos;
-
     ring_buffer<llama_token> prev;
+
+    // a frequency map to count token occurrences
+    std::unordered_map<llama_token, int> token_count;
 };
 
 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
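The new token_count member turns the penalties sampler into an incrementally maintained sliding-window counter: together with the prev ring buffer it tracks how often each token occurred in the last penalty_last_n tokens, updated in O(1) per accepted token instead of being rebuilt from the ring buffer on every apply call (the TODO removed in the hunk below). A minimal standalone sketch of the idea, with std::deque and int standing in for llama.cpp's ring_buffer and llama_token:

#include <deque>
#include <unordered_map>

struct sliding_window_counts {
    size_t capacity; // plays the role of penalty_last_n

    std::deque<int>              prev;  // the last `capacity` tokens, oldest first
    std::unordered_map<int, int> count; // token -> occurrences within the window

    void accept(int token) {
        count[token]++;

        // if the window is full, evict the oldest token and its count
        if (prev.size() >= capacity) {
            const int old = prev.front();
            prev.pop_front();

            if (--count[old] == 0) {
                count.erase(old); // keep only tokens still inside the window
            }
        }

        prev.push_back(token);
    }
};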
@@ -1421,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
         return;
     }
 
-    ctx->prev.push_back(token);
-}
-
-static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+    ctx->token_count[token]++;
 
-    if (ctx->ignore_eos) {
-        assert(ctx->special_eos_id >= 0);
+    // if the ring buffer is full, remove the oldest token
+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        const auto old = ctx->prev.front();
 
-        // optimistically check if the candidates are not yet sorted/shuffled/truncated
-        if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
-            cur_p->data[ctx->special_eos_id].logit = -INFINITY;
-        } else {
-            // else, search for the special EOS token
-            for (size_t i = 0; i < cur_p->size; ++i) {
-                if (cur_p->data[i].id == ctx->special_eos_id) {
-                    cur_p->data[i].logit = -INFINITY;
-                    break;
-                }
-            }
+        ctx->token_count[old]--;
+        if (ctx->token_count[old] == 0) {
+            ctx->token_count.erase(old);
         }
     }
 
-    if ((ctx->penalty_last_n == 0) ||
-        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
-        return;
-    }
-
-    bool nl_found = false;
-    size_t nl_idx = 0;
-    float nl_logit = -INFINITY;
-    if (!ctx->penalize_nl) {
-        assert(ctx->linefeed_id >= 0);
+    ctx->prev.push_back(token);
 
-        // optimistically check if the candidates are not yet sorted/shuffled/truncated
-        if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
-            nl_found = true;
-            nl_idx = ctx->linefeed_id;
-            nl_logit = cur_p->data[ctx->linefeed_id].logit;
-        } else {
-            // else, search for the linefeed token
-            for (size_t i = 0; i < cur_p->size; ++i) {
-                if (cur_p->data[i].id == ctx->linefeed_id) {
-                    nl_found = true;
-                    nl_idx = i;
-                    nl_logit = cur_p->data[i].logit;
-                    break;
-                }
-            }
-        }
+#if 0
+    // sanity check
+    std::unordered_map<llama_token, int> tmp;
+    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+        tmp[ctx->prev.rat(i)]++;
     }
 
-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
+    assert(ctx->token_count == tmp);
+#endif
+}
+
+static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
 
-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
+    if ((ctx->penalty_last_n == 0) ||
+        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
+        return;
     }
 
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
            continue;
        }
 
         const int count = token_iter->second;
 
+        assert(count > 0 && count <= ctx->penalty_last_n);
+
         // The academic publication that described this technique actually just divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is a common fix for the problem: multiply by the penalty instead of dividing.
         if (cur_p->data[i].logit <= 0) {
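Since the hunk cuts off mid-loop, the arithmetic the comment refers to is worth spelling out: a repeat penalty greater than 1.0 divides positive logits but multiplies non-positive ones, so the token always becomes less likely, and the frequency/presence terms are then subtracted based on the window count. A hedged sketch of the per-token update (the helper name and exact form are illustrative, not quoted from the file):

static float penalize_logit(float logit, int count,
                            float penalty_repeat, float penalty_freq, float penalty_present) {
    // repeat penalty: dividing a negative logit by penalty_repeat > 1.0f would
    // push it toward zero (more likely), so non-positive logits are multiplied
    if (logit <= 0.0f) {
        logit *= penalty_repeat;
    } else {
        logit /= penalty_repeat;
    }

    // frequency penalty scales with how often the token appeared in the window;
    // presence penalty is a flat cost for having appeared at all
    logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;

    return logit;
}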
@@ -1503,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
     }
 
     cur_p->sorted = false;
-
-    if (!ctx->penalize_nl && nl_found) {
-        // restore the logit of the newline token if it was penalized
-        cur_p->data[nl_idx].logit = nl_logit;
-    }
 }
 
 static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
     ctx->prev.clear();
+    ctx->token_count.clear();
 }
 
 static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
     auto * result = llama_sampler_init_penalties(
-            ctx->n_vocab,
-            ctx->special_eos_id,
-            ctx->linefeed_id,
             ctx->penalty_last_n,
             ctx->penalty_repeat,
             ctx->penalty_freq,
-            ctx->penalty_present,
-            ctx->penalize_nl,
-            ctx->ignore_eos);
+            ctx->penalty_present);
 
     // copy the state
     {
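The body of the "copy the state" block is cut off by the hunk. Following the clone-then-copy pattern used by the samplers in this file, it presumably assigns the mutable state into the freshly constructed sampler; with the new member, the map has to travel along with the ring buffer. A hedged sketch of what the elided block would contain (the token_count copy is an assumption based on the struct change above):

        // inside llama_sampler_penalties_clone, after llama_sampler_init_penalties(...)
        auto * result_ctx = (llama_sampler_penalties *) result->ctx;

        result_ctx->prev        = ctx->prev;        // ring buffer of recent tokens
        result_ctx->token_count = ctx->token_count; // assumed: cloned along with prev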
@@ -1552,38 +1513,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
 };
 
 struct llama_sampler * llama_sampler_init_penalties(
-        int32_t n_vocab,
-        llama_token special_eos_id,
-        llama_token linefeed_id,
         int32_t penalty_last_n,
         float penalty_repeat,
         float penalty_freq,
-        float penalty_present,
-        bool penalize_nl,
-        bool ignore_eos) {
-    if (linefeed_id == LLAMA_TOKEN_NULL) {
-        penalize_nl = true;
-    }
-
-    if (special_eos_id == LLAMA_TOKEN_NULL) {
-        ignore_eos = false;
-    }
-
+        float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);
 
     return new llama_sampler {
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
-            /* .n_vocab         = */ n_vocab,
-            /* .special_eos_id  = */ special_eos_id,
-            /* .linefeed_id     = */ linefeed_id,
             /* .penalty_last_n  = */ penalty_last_n,
             /* .penalty_repeat  = */ penalty_repeat,
             /* .penalty_freq    = */ penalty_freq,
             /* .penalty_present = */ penalty_present,
-            /* .penalize_nl     = */ penalize_nl,
-            /* .ignore_eos      = */ ignore_eos,
             /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
+            /* .token_count     = */ {},
         },
     };
 }
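With the vocab-dependent arguments gone, constructing the penalties sampler no longer requires any model information; EOS and newline handling are no longer this sampler's concern. A hedged usage sketch of the new four-argument signature, assuming the public llama.h entry points llama_sampler_accept and llama_sampler_free:

#include "llama.h"

void penalties_usage_example() {
    struct llama_sampler * smpl = llama_sampler_init_penalties(
            /* penalty_last_n  = */ 64,    // sliding-window size (ring buffer capacity)
            /* penalty_repeat  = */ 1.1f,  // > 1.0f discourages repeats
            /* penalty_freq    = */ 0.0f,  // frequency penalty disabled
            /* penalty_present = */ 0.0f); // presence penalty disabled

    // after each sampled token: llama_sampler_accept(smpl, token);
    // this keeps prev and token_count in sync for the next apply

    llama_sampler_free(smpl);
}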
@@ -1611,7 +1555,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
         } else {
-            size_t word_len = word.size(), str_len = str.size();
+            size_t word_len = word.size();
+            size_t str_len  = str.size();
             size_t pos = -1;
             while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
                 bool match = true;
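The loop that follows (outside the hunk) scans word for every occurrence of str's first character, then checks whether the suffix of word starting there matches a prefix of str. A hedged standalone sketch of that overlap test, reconstructed from the visible loop header rather than quoted from the file:

#include <string>

// does some suffix of `word`, beginning at an occurrence of str[0],
// match a prefix of `str`? (illustrative helper, not from the file)
static bool tail_overlaps(const std::string & word, const std::string & str) {
    size_t pos = (size_t) -1; // same trick as `size_t pos = -1` above: first find starts at pos + 1 == 0
    while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
        bool match = true;
        for (size_t i = 1; pos + i < word.size(); ++i) {
            if (i >= str.size() || word[pos + i] != str[i]) {
                match = false;
                break;
            }
        }
        if (match) {
            return true;
        }
    }
    return false;
}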