Skip to content

Commit 4f8e55b

Browse files
authored
Fixed RNG to be reproducible
Thanks to @slaren for directions
1 parent f2a2a61 commit 4f8e55b

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

common/sampling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
185185
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
186186
break;
187187
case GPT_SAMPLER_TYPE_XTC:
188-
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_p, params.xtc_t, params.xtc_t_max, params.min_keep));
188+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_p, params.xtc_t, params.xtc_t_max, params.min_keep, params.seed));
189189
break;
190190
case GPT_SAMPLER_TYPE_TFS_Z:
191191
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));

src/llama-sampling.cpp

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,10 +1062,16 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
10621062
// xtc
10631063

10641064
struct llama_sampler_xtc {
1065-
const float probability;
1066-
const float threshold;
1067-
const float threshold_max;
1068-
const size_t min_keep;
1065+
const float probability;
1066+
const float threshold;
1067+
const float threshold_max;
1068+
const size_t min_keep;
1069+
1070+
const uint32_t seed;
1071+
uint32_t seed_cur;
1072+
float chance;
1073+
1074+
std::mt19937 rng;
10691075
};
10701076

10711077
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
@@ -1084,10 +1090,8 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
10841090
|| ctx->min_keep <= 2) {
10851091
return;
10861092
}
1087-
1088-
std::random_device rd;
1089-
float chance = (float)(rd()%100 - 1)/100;
1090-
if (chance > ctx->probability) return;
1093+
// chance is calculated on init and on each reset
1094+
if (ctx->chance > ctx->probability) return;
10911095

10921096
// in case it's not sorted/recalculated yet
10931097
llama_sampler_softmax_impl(cur_p);
@@ -1117,30 +1121,56 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
11171121

11181122
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
11191123
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
1120-
return llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->threshold_max, ctx->min_keep);
1124+
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->threshold_max, ctx->min_keep, ctx->seed);
1125+
1126+
// copy the state
1127+
{
1128+
auto * result_ctx = (llama_sampler_xtc *) result->ctx;
1129+
1130+
result_ctx->rng = ctx->rng;
1131+
}
1132+
1133+
return result;
11211134
}
11221135

11231136
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
11241137
delete (llama_sampler_xtc *) smpl->ctx;
11251138
}
11261139

1140+
static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1141+
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1142+
ctx->seed_cur = get_rng_seed(ctx->seed);
1143+
ctx->rng.seed(ctx->seed_cur);
1144+
1145+
std::uniform_real_distribution<> distance(0.0, 1.0);
1146+
ctx->chance = distance(ctx->rng);
1147+
}
1148+
11271149
static struct llama_sampler_i llama_sampler_xtc_i = {
11281150
/* .name = */ llama_sampler_xtc_name,
11291151
/* .accept = */ nullptr,
11301152
/* .apply = */ llama_sample_xtc_apply,
1131-
/* .reset = */ nullptr,
1153+
/* .reset = */ llama_sampler_xtc_reset,
11321154
/* .clone = */ llama_sampler_xtc_clone,
11331155
/* .free = */ llama_sampler_xtc_free,
11341156
};
11351157

1136-
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep) {
1158+
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) {
1159+
auto seed_cur = get_rng_seed(seed);
1160+
std::uniform_real_distribution<> distance(0.0, 1.0);
1161+
auto rng = std::mt19937(seed_cur);
1162+
float chance = distance(rng);
11371163
return new llama_sampler {
11381164
/* .iface = */ &llama_sampler_xtc_i,
11391165
/* .ctx = */ new llama_sampler_xtc {
11401166
/* .probability = */ p,
11411167
/* .threshold = */ t,
11421168
/* .threshold_max = */ t_max,
11431169
/* .min_keep = */ min_keep,
1170+
/* .seed = */ seed,
1171+
/* .seed_cur = */ seed_cur,
1172+
/* .chance = */ chance,
1173+
/* .rng = */ rng,
11441174
},
11451175
};
11461176
}

0 commit comments

Comments
 (0)