@@ -1386,6 +1386,9 @@ struct llama_sampler_penalties {
     const bool    ignore_eos;
 
     ring_buffer<llama_token> prev;
+
+    // frequency map counting the occurrences of each token in prev
+    std::unordered_map<llama_token, size_t> token_count;
 };
 
 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@@ -1398,7 +1401,13 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
         return;
     }
 
+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        if (--ctx->token_count[ctx->prev.front()] == 0) {
+            ctx->token_count.erase(ctx->prev.front());
+        }
+    }
     ctx->prev.push_back(token);
+    ctx->token_count[token]++;
 }
 
 static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
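The accept hook above keeps the map in sync with the ring buffer: once the buffer holds penalty_last_n tokens, push_back is about to evict the oldest token, so its count is decremented first and the map entry is erased when it reaches zero. Below is a minimal, self-contained sketch of this sliding-window counting scheme, with std::deque standing in for llama.cpp's ring_buffer; token_window and its members are illustrative names, not the library's API.

#include <cassert>
#include <cstdint>
#include <deque>
#include <unordered_map>

using llama_token = std::int32_t; // stand-in for llama.cpp's token type

// A fixed-capacity window of recent tokens plus a running frequency map,
// so counts are updated in O(1) per accepted token instead of being
// rebuilt from the whole window on every apply call.
struct token_window {
    size_t capacity;
    std::deque<llama_token> prev; // stand-in for ring_buffer<llama_token>
    std::unordered_map<llama_token, size_t> token_count;

    void accept(llama_token token) {
        if (prev.size() >= capacity) {
            // the oldest token is about to fall out of the window:
            // decrement its count and drop the entry once it hits zero
            if (--token_count[prev.front()] == 0) {
                token_count.erase(prev.front());
            }
            prev.pop_front();
        }
        prev.push_back(token);
        token_count[token]++;
    }
};

int main() {
    token_window w{/*capacity =*/ 3, {}, {}};
    for (llama_token t : {1, 2, 2, 3}) { // the window holds 3, so token 1 is evicted
        w.accept(t);
    }
    assert(w.token_count.at(2) == 2);
    assert(w.token_count.count(1) == 0); // the evicted token's entry was erased
}

This is what lets the next hunk delete the per-call rebuild: the counts are already current when apply runs.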
@@ -1450,23 +1459,14 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         }
     }
 
-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
-
-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
-    }
-
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
             continue;
         }
 
-        const int count = token_iter->second;
+        const size_t count = token_iter->second;
 
         // The academic publication that described this technique only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // The common fix is to multiply negative logits by the penalty instead of dividing.
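For context, the unchanged code below these comments consumes count in three penalties; here is a hedged sketch of that formula as the comments describe it (apply_penalties and its parameters are illustrative names, not llama.cpp's API):

#include <cstddef>

// Sketch of how a token's window count feeds the three penalties.
float apply_penalties(float logit, size_t count,
                      float penalty_repeat,    // e.g. 1.1f
                      float penalty_freq,      // e.g. 0.1f
                      float penalty_present) { // e.g. 0.1f
    // Repeat penalty: dividing a negative logit by a value > 1 would make
    // the token *more* likely, so negative logits are multiplied instead.
    if (logit <= 0.0f) {
        logit *= penalty_repeat;
    } else {
        logit /= penalty_repeat;
    }
    // The frequency penalty scales with the count; the presence penalty is a
    // flat subtraction for any token that appeared in the window at all.
    logit -= float(count) * penalty_freq + (count > 0 ? penalty_present : 0.0f);
    return logit;
}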
@@ -1561,6 +1561,7 @@ struct llama_sampler * llama_sampler_init_penalties(
             /* .penalize_nl     = */ penalize_nl,
             /* .ignore_eos      = */ ignore_eos,
             /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
+            /* .token_count     = */ std::unordered_map<llama_token, size_t>(),
         },
     };
 }