@@ -1096,37 +1096,40 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     // in case it's not sorted/recalculated yet
     llama_sampler_softmax_impl(cur_p);
 
-    std::vector<llama_token_data> cur;
-
-    int removed = -1; // to keep one out
+    std::vector<llama_token_data> top_tkns;
-    int pos = 0;
+    int pos = -1;
 
-    // going through all candidates from back to front, easier to keep the last of probables
-    for (int i = (cur_p->size - 1); i >= 0; --i) {
-        if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
-            ++removed;
-            if (removed > 0) {
-                // .logits are used for sorting and calculating .p in llama_sample_softmax_impl
-                cur_p->data[i].logit = -999.0f;
-                cur.emplace_back(cur_p->data[i]);
-                pos = i;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].p >= ctx->threshold) {
+            if (cur_p->data[i].p <= ctx->threshold_max) {
+                top_tkns.emplace_back(cur_p->data[i]);
+                // capture the position of the first penalizable token
+                if (pos == -1) pos = i;
             }
-        }
+        } else break;
     }
 
-    if (removed > 0) {
-        size_t size_new = cur_p->size - removed;
+    // check if there are enough penalizable tokens
+    if (top_tkns.size() >= 2) {
+        // keep the least probable of the top tokens
+        top_tkns.pop_back();
+
+        // define the new size
+        size_t to_remove = top_tkns.size();
+        size_t size_new = cur_p->size - to_remove;
 
-        // shift tokens to remove the penalized ones
+        // shift tokens up, starting from pos
         for (size_t i = pos; i < size_new; ++i) {
-            cur_p->data[i] = cur_p->data[i + removed];
+            cur_p->data[i] = cur_p->data[i + to_remove];
         }
 
-        // put the penalized ones at the back
-        for (size_t i = 0; i < cur.size(); ++i) {
-            cur_p->data[cur_p->size - (1 + i)] = cur[i];
+        // penalize top tokens and put them at the back
+        for (size_t i = 0; i < top_tkns.size(); ++i) {
+            top_tkns[i].logit = -999.0f;
+            cur_p->data[cur_p->size - (1 + i)] = top_tkns[i];
         }
 
+        // resize, but never below min_keep
         if (size_new < ctx->min_keep) size_new = ctx->min_keep;
         cur_p->size = size_new;
     }
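
For context, below is a minimal standalone sketch of the forward-scan logic this hunk arrives at. The TokenData struct, the hard-coded threshold values, and main() are illustrative assumptions standing in for llama_token_data, the sampler context fields (ctx->threshold, ctx->threshold_max, ctx->min_keep), and the surrounding sampler plumbing; it is a sketch of the technique under those assumptions, not the actual llama.cpp implementation.

#include <cstdio>
#include <vector>

// Simplified stand-in for llama_token_data; field names mirror the diff.
struct TokenData {
    int   id;
    float logit;
    float p;    // probability, assumed already softmaxed and sorted descending
};

int main() {
    // Hypothetical values for ctx->threshold / ctx->threshold_max / ctx->min_keep.
    const float  threshold     = 0.10f;
    const float  threshold_max = 0.50f;
    const size_t min_keep      = 1;

    // Toy candidate list, sorted by p as llama_sampler_softmax_impl guarantees.
    std::vector<TokenData> cand = {
        {0, 2.0f, 0.40f}, {1, 1.5f, 0.25f}, {2, 1.0f, 0.15f},
        {3, 0.5f, 0.12f}, {4, 0.2f, 0.05f}, {5, 0.1f, 0.03f},
    };

    std::vector<TokenData> top_tkns;
    int pos = -1;

    // Collect the contiguous band of penalizable tokens; since the list is
    // sorted, the first token below `threshold` ends the scan.
    for (size_t i = 0; i < cand.size(); ++i) {
        if (cand[i].p >= threshold) {
            if (cand[i].p <= threshold_max) {
                top_tkns.push_back(cand[i]);
                if (pos == -1) pos = (int) i;   // first penalizable position
            }
        } else break;
    }

    size_t size_new = cand.size();
    if (top_tkns.size() >= 2) {
        top_tkns.pop_back();                    // spare the least probable one
        size_t to_remove = top_tkns.size();
        size_new = cand.size() - to_remove;

        // Shift the survivors left over the removed band...
        for (size_t i = (size_t) pos; i < size_new; ++i) {
            cand[i] = cand[i + to_remove];
        }
        // ...and park the penalized tokens at the back with a crushed logit.
        for (size_t i = 0; i < to_remove; ++i) {
            top_tkns[i].logit = -999.0f;
            cand[cand.size() - (1 + i)] = top_tkns[i];
        }
        if (size_new < min_keep) size_new = min_keep;
    }

    for (size_t i = 0; i < size_new; ++i) {
        printf("kept id=%d p=%.2f\n", cand[i].id, cand[i].p);
    }
    return 0;
}

Run against the toy distribution, this prints the three survivors (p = 0.12, 0.05, 0.03): every candidate in the [threshold, threshold_max] band except the least probable one is penalized and moved past size_new, which is exactly the rearrangement the shift loop in the hunk performs.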