@@ -596,6 +596,45 @@ class chat
596596 // std::getline(std::cin, pause);
597597 }
598598
599+ bool logit_bias_check_exact (std::string_view token_str) {
600+ for (auto word : params.sparams .logit_bias_strings_exact ) {
601+ if (token_str == word) return true ;
602+ }
603+
604+ return false ;
605+ }
606+
607+ bool logit_bias_check_beginning (std::string_view token_str) {
608+ for (auto word : params.sparams .logit_bias_strings_beginning ) {
609+ if ((token_str.find (word) == 0 && (token_str.length () - word.length ()) < 4 ) ||
610+ (token_str.length () > 2 && word.find (token_str) == 0 )
611+ ) return true ;
612+ }
613+
614+ return false ;
615+ }
616+
617+ bool logit_bias_check_ending (std::string_view token_str) {
618+ for (auto word : params.sparams .logit_bias_strings_ending ) {
619+ auto token_str_pos = word.find (token_str);
620+ if (token_str_pos == (token_str.length () - 1 )) return true ;
621+ }
622+
623+ return false ;
624+ }
625+
626+ bool logit_bias_checks (std::string token_str) {
627+ if (token_str.front () == ' ' ) {
628+ token_str = token_str.substr (1 );
629+ }
630+
631+ if (token_str.back () == ' ' ) {
632+ token_str.pop_back ();
633+ }
634+
635+ return logit_bias_check_exact (token_str) || logit_bias_check_beginning (token_str) || logit_bias_check_ending (token_str);
636+ }
637+
599638 void logit_bias_postfill (llama_token & id, std::string token_str) {
600639 // cutting spaces since there are "duplicated" tokens with them
601640 if (token_str.front () == ' ' ) {
@@ -687,14 +726,21 @@ class chat
687726 }
688727
689728 void processByVocab (std::string safeguard_string) {
729+ bool has_logit_biases_detailed = (params.sparams .logit_bias_strings_exact .size () || params.sparams .logit_bias_strings_beginning .size () || params.sparams .logit_bias_strings_ending .size ());
730+
690731 bool has_logit_biases = (params.sparams .logit_bias_strings .size () || params.sparams .logit_bias_strings_ext .size ());
691732 bool has_logit_biases_start = params.sparams .logit_bias_strings_start .size ();
692733
693734 for (llama_token id = 0 ; id < llama_vocab_n_tokens (vocab); id++) {
694735 std::string token_str = common_token_to_piece (ctx, id);
695736
696- if (has_logit_biases) logit_bias_postfill (id, token_str);
697- if (has_logit_biases_start) start_bias_tokens_postfill (id, token_str);
737+ if (has_logit_biases_detailed == true && logit_bias_checks (token_str) == true ) {
738+ params.sparams .logit_bias .push_back ({id, -INFINITY});
739+ } else if (has_logit_biases == true ) {
740+ logit_bias_postfill (id, token_str);
741+ }
742+
743+ if (has_logit_biases_start == true ) start_bias_tokens_postfill (id, token_str);
698744 if (safeguard_token < 0 ) get_safeguard_token (id, token_str, safeguard_string);
699745 }
700746
@@ -1636,7 +1682,7 @@ class chat
16361682 if (id == l) {
16371683 checks = 0 ;
16381684 std::string c_restricted_tkn_string = common_token_to_piece (ctx, id);
1639- writeTextFile (" logit_biasing.txt" , std::format (" Found: '{}';" , c_restricted_tkn_string));
1685+ writeTextFile (" logit_biasing.txt" , std::format (" {}: Found '{}';" , params. sparams . seed , c_restricted_tkn_string));
16401686
16411687 id = common_sampler_shift (smpl, ctx, -1 , id);
16421688
@@ -1659,7 +1705,7 @@ class chat
16591705 // --attempts;
16601706 // }
16611707
1662- if (biased_logit.token == id) {
1708+ if (biased_logit.bias < - 9 && biased_logit. token == id) {
16631709 ++c_restricted_tkns;
16641710 // std::string c_restricted_tkn_string = common_token_to_piece(ctx, id);
16651711 // writeTextFile("logit_biasing.txt", std::format("+{}\n", c_restricted_tkn_string));
0 commit comments