@@ -604,10 +604,73 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
604604static void llama_sampler_dist_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p) {
605605 auto * ctx = (llama_sampler_dist *) smpl->ctx ;
606606
607- // sorting is not necessary here
608- llama_sampler_softmax_impl (cur_p, false );
607+ // edge cases
608+ if (cur_p->size == 0 ) {
609+ cur_p->selected = -1 ;
610+ return ;
611+ }
612+
613+ cur_p->selected = 0 ;
614+
615+ if (cur_p->size == 1 ) {
616+ cur_p->data [0 ].p = 1 .0f ;
617+ return ;
618+ }
619+
620+ // max logit for numerical stability
621+ float max_l = cur_p->data [0 ].logit ;
622+ if (!cur_p->sorted ) {
623+ for (size_t i = 1 ; i < cur_p->size ; ++i) {
624+ max_l = std::max (max_l, cur_p->data [i].logit );
625+ }
626+ }
627+
628+ // apply softmax to obtain the probabilities
629+ double sum_cum = 0 .0f ;
630+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
631+ float p = expf (cur_p->data [i].logit - max_l);
632+ cur_p->data [i].p = p;
633+ sum_cum += p;
634+ }
635+
636+ #if 1
637+ // sample from the obtained probabilities and normalize the probs in a single pass
638+ // this is ~3x faster on Mac with full gpt-oss vocab than the version below
639+ //
640+ std::uniform_real_distribution<double > dist (0 .0f , 1 .0f );
641+ const double rnd = dist (ctx->rng );
642+
643+ double sum_run = 0 .0f ;
644+ const double sum_tgt = sum_cum*rnd;
645+
646+ bool found = false ;
647+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
648+ if (!found) {
649+ // accumulate probs until we reach the target sum
650+ sum_run += cur_p->data [i].p ;
651+ if (sum_run >= sum_tgt) {
652+ cur_p->selected = i;
653+ found = true ;
654+ }
655+ }
656+
657+ // normalize probs
658+ cur_p->data [i].p /= sum_cum;
659+ }
660+
661+ // fallback to the last token (don't think this can happen)
662+ assert (found);
663+ if (!found) {
664+ cur_p->selected = cur_p->size - 1 ;
665+ }
666+ #else
667+ // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
668+ for (size_t i = 0; i < cur_p->size; ++i) {
669+ cur_p->data[i].p /= sum_cum;
670+ }
609671
610672 cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
673+ #endif
611674}
612675
613676static struct llama_sampler * llama_sampler_dist_clone (const struct llama_sampler * smpl) {
0 commit comments