@@ -681,6 +681,8 @@ static void llama_sampler_min_p_addon_apply(struct llama_sampler * smpl, llama_t
681681 if (!filtered_tokens.empty () && filtered_tokens.size () >= ctx->min_keep ) {
682682 memcpy (cur_p->data , filtered_tokens.data (), filtered_tokens.size ()*sizeof (llama_token_data));
683683 cur_p->size = filtered_tokens.size ();
684+ // Guard against a single choice
685+ if (cur_p->size < 2 ) cur_p->size = 2 ;
684686 min_p_applied = true ;
685687 }
686688 }
@@ -706,6 +708,9 @@ static void llama_sampler_min_p_addon_apply(struct llama_sampler * smpl, llama_t
706708 }
707709 }
708710
711+ // Guard against a single choice
712+ if (i < 2 ) i = 2 ;
713+
709714 // Resize the output vector to keep only the matching tokens
710715 cur_p->size = i;
711716 }
@@ -2190,5 +2195,111 @@ struct llama_sampler * llama_sampler_init_logit_bias_addon(
21902195 );
21912196}
21922197
2198+ // logit-bias-start
2199+
2200+ struct llama_sampler_logit_bias_start_addon {
2201+ const int32_t n_vocab;
2202+
2203+ const std::vector<llama_logit_bias> logit_bias;
2204+
2205+ std::vector<llama_logit_bias> to_search;
2206+ };
2207+
// Returns the human-readable name of this sampler.
//
// The sampler argument is ignored. The returned pointer refers to a string
// literal, so it stays valid for the whole program and must not be freed.
static const char * llama_sampler_logit_bias_start_addon_name(const struct llama_sampler * /*smpl*/) {
    // No leading/trailing whitespace: the name is used as an identifier-like label.
    return "logit-bias";
}
2211+
2212+ static void llama_sampler_logit_bias_start_addon_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2213+ auto * ctx = (llama_sampler_logit_bias_start_addon *) smpl->ctx ;
2214+
2215+ if (ctx->logit_bias .empty ()) {
2216+ // std::string logits_orig = "\nLOGITS: EMPTY\n";
2217+ // if (test_dumbed_logits_biased == false) {
2218+ // writeToFile("logit_biasing.txt", logits_orig);
2219+ // test_dumbed_logits_biased = true;
2220+ // }
2221+ return ;
2222+ }
2223+
2224+ ctx->to_search .clear ();
2225+
2226+ // std::string logits_orig = "\nLOGITS:\n";
2227+ // std::string logits_positive = "\nLOGITS POS:\n";
2228+ // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
2229+ for (const auto & lb : ctx->logit_bias ) {
2230+ if (lb.token >= 0 && cur_p->size > (size_t ) lb.token && cur_p->data [lb.token ].id == lb.token ) {
2231+ // if (lb.bias < 0) {
2232+ // logits_orig += std::to_string(cur_p->data[lb.token].id) + ": " + std::to_string(cur_p->data[lb.token].logit) + " -> ";
2233+ // } else logits_positive += std::to_string(cur_p->data[lb.token].id) + ": " + std::to_string(cur_p->data[lb.token].logit) + " -> ";
2234+ cur_p->data [lb.token ].logit += lb.bias ;
2235+ // if (lb.bias < 0) {
2236+ // logits_orig += std::to_string(cur_p->data[lb.token].logit) + ";\n";
2237+ // } else logits_positive += std::to_string(cur_p->data[lb.token].logit) + ";\n";
2238+ } else {
2239+ ctx->to_search .push_back (lb);
2240+ }
2241+ }
2242+
2243+ if (ctx->to_search .empty ()) {
2244+ // if (test_dumbed_logits_biased == false) {
2245+ // logits_orig += logits_positive + "\nNO SEARCH\n";
2246+ // writeToFile("logit_biasing.txt", logits_orig);
2247+ // test_dumbed_logits_biased = true;
2248+ // }
2249+ return ;
2250+ }
2251+
2252+ // search for the remaining candidates that were not found in the previous step
2253+ // logits_orig += "\nSEARCH:\n";
2254+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
2255+ for (const auto & lb : ctx->to_search ) {
2256+ if (cur_p->data [i].id == lb.token ) {
2257+ // logits_orig += std::to_string(cur_p->data[i].logit) + "->";
2258+ cur_p->data [i].logit += lb.bias ;
2259+ // logits_orig += std::to_string(cur_p->data[i].logit) + ";\n";
2260+ break ;
2261+ }
2262+ }
2263+ }
2264+
2265+ // if (test_dumbed_logits_biased == false) {
2266+ // logits_orig += logits_positive;
2267+ // writeToFile("logit_biasing.txt", logits_orig);
2268+ // test_dumbed_logits_biased = true;
2269+ // }
2270+ }
2271+
2272+ static struct llama_sampler * llama_sampler_logit_bias_start_addon_clone (const struct llama_sampler * smpl) {
2273+ const auto * ctx = (const llama_sampler_logit_bias_start_addon *) smpl->ctx ;
2274+ return llama_sampler_init_logit_bias_start_addon (ctx->n_vocab , ctx->logit_bias .size (), ctx->logit_bias .data ());
2275+ }
2276+
2277+ static void llama_sampler_logit_bias_start_addon_free (struct llama_sampler * smpl) {
2278+ delete (llama_sampler_logit_bias_start_addon *) smpl->ctx ;
2279+ }
2280+
2281+ static struct llama_sampler_i llama_sampler_logit_bias_start_addon_i = {
2282+ /* .name = */ llama_sampler_logit_bias_start_addon_name,
2283+ /* .accept = */ nullptr ,
2284+ /* .apply = */ llama_sampler_logit_bias_start_addon_apply,
2285+ /* .reset = */ nullptr ,
2286+ /* .clone = */ llama_sampler_logit_bias_start_addon_clone,
2287+ /* .free = */ llama_sampler_logit_bias_start_addon_free,
2288+ };
2289+
2290+ struct llama_sampler * llama_sampler_init_logit_bias_start_addon (
2291+ int32_t n_vocab,
2292+ int32_t n_logit_bias,
2293+ const llama_logit_bias * logit_bias) {
2294+ return llama_sampler_init (
2295+ /* .iface = */ &llama_sampler_logit_bias_start_addon_i,
2296+ /* .ctx = */ new llama_sampler_logit_bias_start_addon {
2297+ /* .n_vocab = */ n_vocab,
2298+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
2299+ /* .to_search = */ {},
2300+ }
2301+ );
2302+ }
2303+
21932304
21942305
0 commit comments