@@ -260,6 +260,7 @@ static void llama_log_softmax(float * array, size_t size) {
 */
 
 static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
+    cur_p->normalized = false;
     if (temp <= 0.0f) {
         // find the token with the highest logit and set the rest to -inf
         size_t max_i = 0;
@@ -309,6 +310,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_s
     for (size_t i = 0; i < cur_p->size; ++i) {
         cur_p->data[i].p /= cum_sum;
     }
+    cur_p->normalized = true;
 }
 
 static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
@@ -328,6 +330,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
     }
 
     cur_p->size = k;
+    cur_p->normalized = false;
 }
 
 static uint32_t get_rng_seed(uint32_t seed) {
@@ -422,6 +425,7 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
     llama_token_data_array cur_p = {
         /* .data       = */ cur.data(),
         /* .size       = */ cur.size(),
+        /* .normalized = */ false,
         /* .selected   = */ -1,
         /* .sorted     = */ false,
     };
@@ -614,6 +618,23 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
 
     if (cur_p->size == 1) {
         cur_p->data[0].p = 1.0f;
+        cur_p->normalized = true;
+        return;
+    }
+
+    if (cur_p->normalized) {
+        std::uniform_real_distribution<double> dist(0.0f, 1.0f);
+        const double rnd = dist(ctx->rng);
+        double sum_run = 0.0f;
+
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            sum_run += cur_p->data[i].p;
+            if (sum_run >= rnd) {
+                cur_p->selected = i;
+                return;
+            }
+        }
+        cur_p->selected = cur_p->size - 1;
         return;
     }
 
@@ -663,13 +684,15 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
     if (!found) {
         cur_p->selected = cur_p->size - 1;
     }
+    cur_p->normalized = true;
 #else
     // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
     for (size_t i = 0; i < cur_p->size; ++i) {
         cur_p->data[i].p /= sum_cum;
     }
 
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+    cur_p->normalized = true;
 #endif
 }
 
@@ -780,7 +803,9 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
         return;
     }
 
-    llama_sampler_softmax_impl(cur_p, false);
+    if (!cur_p->normalized) {
+        llama_sampler_softmax_impl(cur_p, false);
+    }
 
     size_t k = cur_p->size;
     auto * pdata = cur_p->data;
@@ -826,6 +851,7 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
     }
 
     cur_p->size = last_idx;
+    cur_p->normalized = false;
 }
 
 static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
@@ -897,6 +923,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
         if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
             std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data);
             cur_p->size = filtered_tokens.size();
+            cur_p->normalized = false;
             min_p_applied = true;
         }
     }
@@ -919,6 +946,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
 
         // Resize the output vector to keep only the matching tokens
         cur_p->size = i;
+        cur_p->normalized = false;
     }
 }
 
@@ -971,7 +999,9 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
     }
 
     // Compute the softmax of logits and calculate entropy
-    llama_sampler_softmax_impl(cur_p, true);
+    if (!cur_p->normalized) {
+        llama_sampler_softmax_impl(cur_p, true);
+    }
 
     float entropy = 0.0f;
     for (size_t i = 0; i < cur_p->size; ++i) {
@@ -1019,6 +1049,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
     std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
     cur_p->size = cur_p_new.size();
     cur_p->sorted = false;
+    cur_p->normalized = false;
 }
 
 static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
@@ -1120,7 +1151,9 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
     // Calculate maximum possible entropy
     float max_entropy = -logf(1.0f / cur_p->size);
 
-    llama_sampler_softmax_impl(cur_p, true);
+    if (!cur_p->normalized) {
+        llama_sampler_softmax_impl(cur_p, true);
+    }
 
     // Calculate entropy of the softmax probabilities
     float entropy = 0.0f;
@@ -1162,6 +1195,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
     for (size_t i = 0; i < cur_p->size; ++i) {
         cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
     }
+    cur_p->normalized = true;
 
 #ifdef DEBUG
     // Print the updated top 25 probabilities after temperature scaling
@@ -1236,7 +1270,9 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
         return;
     }
 
-    llama_sampler_softmax_impl(cur_p, true);
+    if (!cur_p->normalized) {
+        llama_sampler_softmax_impl(cur_p, true);
+    }
 
     int pos_last = 0;
 
@@ -1251,6 +1287,7 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
     if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
         cur_p->data += pos_last;
         cur_p->size -= pos_last;
+        cur_p->normalized = false;
     }
 }
 
@@ -1327,7 +1364,9 @@ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*s
 static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
 
-    llama_sampler_softmax_impl(cur_p, true);
+    if (!cur_p->normalized) {
+        llama_sampler_softmax_impl(cur_p, true);
+    }
 
     // Estimate s_hat using the most probable m tokens
     float s_hat = 0.0;
@@ -1433,7 +1472,9 @@ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler *
 static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
 
-    llama_sampler_softmax_impl(cur_p, true);
+    if (!cur_p->normalized) {
+        llama_sampler_softmax_impl(cur_p, true);
+    }
 
     // Truncate the words with surprise values greater than mu
     cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
@@ -1775,6 +1816,7 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
     }
 
     cur_p->sorted = false;
+    cur_p->normalized = false;
 }
 
 static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
@@ -2193,6 +2235,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
     }
 
     cur_p->sorted = false;
+    cur_p->normalized = false;
 }
 
 static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
@@ -2344,6 +2387,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
     }
 
     if (ctx->to_search.empty()) {
+        cur_p->normalized = false;
         return;
     }
 
@@ -2356,6 +2400,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
             }
         }
     }
+    cur_p->normalized = false;
 }
 
 static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
@@ -2408,7 +2453,9 @@ static const char * llama_sampler_infill_name(const struct llama_sampler * /*smp
 static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_infill *) smpl->ctx;
 
-    llama_sampler_softmax_impl(cur_p, true);
+    if (!cur_p->normalized) {
+        llama_sampler_softmax_impl(cur_p, true);
+    }
 
 #if defined(GGML_DEBUG_SAMPLER_INFILL)
 #define LOG_DBG_CUR LLAMA_LOG_DEBUG
@@ -2457,6 +2504,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
         for (size_t i = 0; i < cur_p->size; ++i) {
             cur_p->data[i].p /= p_sum;
         }
+        cur_p->normalized = true;
 
         return;
     }
@@ -2542,6 +2590,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
         cur_p->size = 1;
         cur_p->data[0].id = ctx->vocab->token_eot();
         cur_p->data[0].logit = 1.0f;
+        cur_p->normalized = true;
 
         return;
     }
@@ -2579,6 +2628,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
 
         LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
+    cur_p->normalized = true;
 
 #undef LOG_DBG_CUR
 }
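
For reference, here is a minimal self-contained sketch of the pattern this patch applies (the struct and helper names below are simplified stand-ins, not the actual llama.cpp API): softmax sets a "normalized" flag, every step that filters or rescales the candidates clears it, and the final sampler checks the flag so it only renormalizes when needed before drawing a token by inverse CDF.

// Minimal sketch of the "normalized" flag pattern; simplified stand-ins, not the real llama.cpp types.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

struct token_data  { int id; float logit; float p; };
struct token_array {
    std::vector<token_data> data;
    bool normalized = false; // true once the p fields sum to 1
};

// Softmax over the logits; marks the array as normalized.
static void softmax(token_array & arr) {
    float max_logit = arr.data[0].logit;
    for (const auto & t : arr.data) max_logit = std::max(max_logit, t.logit);
    float sum = 0.0f;
    for (auto & t : arr.data) { t.p = std::exp(t.logit - max_logit); sum += t.p; }
    for (auto & t : arr.data) t.p /= sum;
    arr.normalized = true;
}

// Any filtering step (top-k, top-p, ...) must clear the flag, for example:
static void keep_top_k(token_array & arr, size_t k) {
    std::sort(arr.data.begin(), arr.data.end(),
              [](const token_data & a, const token_data & b) { return a.logit > b.logit; });
    if (arr.data.size() > k) arr.data.resize(k);
    arr.normalized = false; // probabilities no longer sum to 1
}

// Final draw: normalize only if no earlier step already did, then inverse-CDF sample.
static size_t sample_dist(token_array & arr, std::mt19937 & rng) {
    if (!arr.normalized) softmax(arr);
    std::uniform_real_distribution<double> dist(0.0, 1.0);
    const double rnd = dist(rng);
    double cum = 0.0;
    for (size_t i = 0; i < arr.data.size(); ++i) {
        cum += arr.data[i].p;
        if (cum >= rnd) return i;
    }
    return arr.data.size() - 1; // guard against floating-point rounding
}

int main() {
    token_array arr = { { {0, 1.0f, 0.0f}, {1, 2.0f, 0.0f}, {2, 0.5f, 0.0f} }, false };
    keep_top_k(arr, 2);
    std::mt19937 rng(42);
    printf("selected index: %zu\n", sample_dist(arr, rng));
    return 0;
}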