@@ -1062,10 +1062,16 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
10621062// xtc
10631063
// State of the XTC sampler.
// The first four fields and `seed` are fixed at construction; `seed_cur`, `chance` and `rng`
// form the mutable random state that is re-derived on reset.
struct llama_sampler_xtc {
    const float  probability;   // apply is skipped when `chance` is greater than this value
    const float  threshold;     // NOTE(review): presumably the lower probability bound used by apply - confirm
    const float  threshold_max; // NOTE(review): presumably the upper probability bound used by apply - confirm
    const size_t min_keep;      // part of the early-return guard at the top of apply

    const uint32_t seed;        // seed as requested by the caller; passed through get_rng_seed()
    uint32_t       seed_cur;    // value returned by get_rng_seed(seed), actually used to seed `rng`
    float          chance;      // roll drawn from `rng` on init and on each reset, compared to `probability`

    std::mt19937 rng;
};
10701076
10711077static const char * llama_sampler_xtc_name (const struct llama_sampler * /* smpl*/ ) {
@@ -1084,10 +1090,8 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
10841090 || ctx->min_keep <= 2 ) {
10851091 return ;
10861092 }
1087-
1088- std::random_device rd;
1089- float chance = (float )(rd ()%100 - 1 )/100 ;
1090- if (chance > ctx->probability ) return ;
1093+ // chance is calculated on init and on each reset
1094+ if (ctx->chance > ctx->probability ) return ;
10911095
10921096 // in case it's not sorted/recalculated yet
10931097 llama_sampler_softmax_impl (cur_p);
@@ -1117,30 +1121,56 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
11171121
11181122static struct llama_sampler * llama_sampler_xtc_clone (const struct llama_sampler * smpl) {
11191123 const auto * ctx = (const llama_sampler_xtc *) smpl->ctx ;
1120- return llama_sampler_init_xtc (ctx->probability , ctx->threshold , ctx->threshold_max , ctx->min_keep );
1124+ auto * result = llama_sampler_init_xtc (ctx->probability , ctx->threshold , ctx->threshold_max , ctx->min_keep , ctx->seed );
1125+
1126+ // copy the state
1127+ {
1128+ auto * result_ctx = (llama_sampler_xtc *) result->ctx ;
1129+
1130+ result_ctx->rng = ctx->rng ;
1131+ }
1132+
1133+ return result;
11211134}
11221135
11231136static void llama_sampler_xtc_free (struct llama_sampler * smpl) {
11241137 delete (llama_sampler_xtc *) smpl->ctx ;
11251138}
11261139
1140+ static void llama_sampler_xtc_reset (struct llama_sampler * smpl) {
1141+ auto * ctx = (llama_sampler_xtc *) smpl->ctx ;
1142+ ctx->seed_cur = get_rng_seed (ctx->seed );
1143+ ctx->rng .seed (ctx->seed_cur );
1144+
1145+ std::uniform_real_distribution<> distance (0.0 , 1.0 );
1146+ ctx->chance = distance (ctx->rng );
1147+ }
1148+
11271149static struct llama_sampler_i llama_sampler_xtc_i = {
11281150 /* .name = */ llama_sampler_xtc_name,
11291151 /* .accept = */ nullptr ,
11301152 /* .apply = */ llama_sample_xtc_apply,
1131- /* .reset = */ nullptr ,
1153+ /* .reset = */ llama_sampler_xtc_reset ,
11321154 /* .clone = */ llama_sampler_xtc_clone,
11331155 /* .free = */ llama_sampler_xtc_free,
11341156};
11351157
1136- struct llama_sampler * llama_sampler_init_xtc (float p, float t, float t_max, size_t min_keep) {
1158+ struct llama_sampler * llama_sampler_init_xtc (float p, float t, float t_max, size_t min_keep, uint32_t seed) {
1159+ auto seed_cur = get_rng_seed (seed);
1160+ std::uniform_real_distribution<> distance (0.0 , 1.0 );
1161+ auto rng = std::mt19937 (seed_cur);
1162+ float chance = distance (rng);
11371163 return new llama_sampler {
11381164 /* .iface = */ &llama_sampler_xtc_i,
11391165 /* .ctx = */ new llama_sampler_xtc {
11401166 /* .probability = */ p,
11411167 /* .threshold = */ t,
11421168 /* .threshold_max = */ t_max,
11431169 /* .min_keep = */ min_keep,
1170+ /* .seed = */ seed,
1171+ /* .seed_cur = */ seed_cur,
1172+ /* .chance = */ chance,
1173+ /* .rng = */ rng,
11441174 },
11451175 };
11461176}
0 commit comments