@@ -1075,37 +1075,43 @@ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/
10751075static  void  llama_sample_xtc_apply (struct  llama_sampler  * smpl, llama_token_data_array * cur_p) {
10761076    const  auto  * ctx = (llama_sampler_xtc *) smpl->ctx ;
10771077
1078-     if  (ctx->probability  <= 0 .0f  || ctx->threshold  <= 0 .0f  || cur_p->size  <= 1  || ctx->min_keep  <= 2 ) {
1078+     if  (ctx->probability  <= 0 .0f  
1079+         || ctx->threshold  <= 0 .0f  
1080+         || ctx->threshold  >= 1 .0f  
1081+         || ctx->threshold_max  <= 0 .0f  
1082+         || ctx->threshold_max  <= ctx->threshold  
1083+         || cur_p->size  <= 2  
1084+         || ctx->min_keep  <= 2 ) {
10791085        return ;
10801086    }
10811087
10821088    std::random_device rd;
1083-     float  chance = (float )(rd ()%100 )/100 ;
1089+     float  chance = (float )(rd ()%100  -  1 )/100 ;
10841090    if  (chance > ctx->probability ) return ;
1091+ 
10851092    //  in case it's not sorted/recalculated yet
10861093    llama_sampler_softmax_impl (cur_p);
10871094
1088-     int  removed  = 0 ;
1095+     int  found  = 0 ;
10891096    //  going through all candidates from back to front, easier to keep the last of probables
10901097    for  (int  i = (cur_p->size  - 1 ); i >= 0 ; --i) {
10911098        if  (cur_p->data [i].p  >= ctx->threshold  && cur_p->data [i].p  <= ctx->threshold_max ) {
1092-             ++removed ;
1093-             if  (removed  > 1 ) {
1099+             ++found ;
1100+             if  (found  > 1 ) {
10941101                //  .logits are used for sorting and calculating .p in llama_sample_softmax_impl
10951102                cur_p->data [i].logit  = -999 .0f ;
10961103            }
10971104        }
10981105    }
10991106
1100-     if  (removed  > 1 ) {
1107+     if  (found  > 1 ) {
11011108        //  sorting with new logits, ex-last probable will be the first anyway
11021109        std::sort (cur_p->data , cur_p->data  + cur_p->size , [](const  llama_token_data & a, const  llama_token_data & b) {
11031110            return  a.logit  > b.logit ;
11041111        });
1105-         cur_p->sorted  = true ;
11061112
11071113        //  resizing now that penalized tokens are at the back
1108-         cur_p->size  = cur_p->size  - removed  + 1 ;
1114+         cur_p->size  = cur_p->size  - found  + 1 ;
11091115    }
11101116}
11111117
0 commit comments