@@ -270,6 +270,7 @@ class chat
270270 int n_embd_inp_last = 0 ;
271271 int c_empty_msgs = 0 ;
272272 int c_restricted_tkns = 0 ;
273+ int safeguard_token = -1 ;
273274
274275 // experimenting with simple dynamic paramenters
275276 float d_temp_min = 1.8 ;
@@ -286,13 +287,16 @@ class chat
286287 std::string txt_vocab_eos = " " ;
287288 std::string logit_bias_strings_display = " " ;
288289 std::string logit_bias_strings_ext_display = " " ;
290+ std::string logit_bias_strings_start_display = " " ;
289291
290292 struct llama_perf_context_data ctx_performance_data;
291293
292294 // std::map<std::string,std::string> stats;
293295
294296 ring_buffer<llama_token> prev_state = ring_buffer<llama_token>(std::max(32 , params.sparams.n_prev));
295297
298+ std::vector<llama_token> logit_bias_tokens_start;
299+
296300 chat (int argc, char ** argv){
297301 init (argc, argv);
298302 }
@@ -551,7 +555,7 @@ class chat
551555 if (token_str_pos == 0 || token_str_pos == (word.size () - 1 )) {
552556 restricted = true ;
553557 break ;
554- } else if (token_str.find (word) == 0 && (token_str.length () - word.length () < 4 ) ) {
558+ } else if (token_str.find (word) == 0 && (token_str.length () - word.length ()) < 4 ) {
555559 restricted = true ;
556560 break ;
557561 }
@@ -592,44 +596,127 @@ class chat
592596 // std::getline(std::cin, pause);
593597 }
594598
595- void sparams_postfill_ext () {
596- // std::string space = " ";
597- if (params.sparams .logit_bias_strings_ext .size ()) {
598- for (llama_token i = 0 ; i < llama_vocab_n_tokens (vocab); i++) {
599- std::string token_str = common_token_to_piece (ctx, i);
600- // cutting spaces since there are "duplicated" tokens with them
601- if (token_str.front () == ' ' ) {
602- token_str = token_str.substr (1 );
599+ void logit_bias_postfill (llama_token & id, std::string token_str) {
600+ // cutting spaces since there are "duplicated" tokens with them
601+ if (token_str.front () == ' ' ) {
602+ token_str = token_str.substr (1 );
603+ }
604+
605+ // almost never happens
606+ if (token_str.back () == ' ' ) {
607+ token_str.pop_back ();
608+ }
609+
610+ bool restricted = false ;
611+ float bias = -INFINITY;
612+
613+ if (token_str.length () > 2 ) {
614+ for (auto word : params.sparams .logit_bias_strings ) {
615+ auto token_str_pos = word.find (token_str);
616+
617+ // if (token_str_pos == 0 || token_str_pos == (word.size() - 1)) {
618+ if (token_str_pos == 0 ) {
619+ restricted = true ;
620+ break ;
621+ } else if (token_str.find (word) == 0 && token_str.length () >= word.length () && (token_str.length () - word.length ()) < 4 ) {
622+ restricted = true ;
623+ break ;
603624 }
625+ }
604626
605- // almost never happens
606- if (token_str.back () == ' ' ) {
607- token_str.pop_back ();
627+ for (auto word_bias_pair : params.sparams .logit_bias_strings_ext ) {
628+ auto token_str_pos = word_bias_pair.first .find (token_str);
629+ // if (token_str_pos == 0 || token_str_pos == (word_bias_pair.first.size() - 1) ) {
630+ if (token_str_pos == 0 || token_str_pos == (word_bias_pair.first .size () - 1 ) ) {
631+ restricted = true ;
632+ bias = word_bias_pair.second ;
633+ break ;
634+ }
635+ }
636+ } else if (token_str.length () > 0 ) {
637+ for (auto word : params.sparams .logit_bias_strings ) {
638+ if (token_str == word) {
639+ restricted = true ;
640+ break ;
608641 }
642+ }
609643
610- if (token_str.length () > 1 ) {
611- for (auto word_bias_pair : params.sparams .logit_bias_strings_ext ) {
612- auto token_str_pos = word_bias_pair.first .find (token_str);
613- if (token_str_pos == 0 || token_str_pos == (word_bias_pair.first .size () - 1 ) ) {
614- params.sparams .logit_bias .push_back ({i, word_bias_pair.second });
615- break ;
616- }
617- }
644+ for (auto word_bias_pair : params.sparams .logit_bias_strings_ext ) {
645+ if (token_str == word_bias_pair.first ) {
646+ restricted = true ;
647+ bias = word_bias_pair.second ;
648+ break ;
618649 }
619650 }
620651 }
652+
653+ if (restricted == true ) {
654+ params.sparams .logit_bias .push_back ({id, bias});
655+ }
656+ }
657+
658+ void start_bias_tokens_postfill (llama_token & id, std::string token_str) {
659+ if (token_str.front () == ' ' ) {
660+ token_str = token_str.substr (1 );
661+ }
662+
663+ if (token_str.back () == ' ' ) {
664+ token_str.pop_back ();
665+ }
666+
667+ for (auto word : params.sparams .logit_bias_strings_start ) {
668+ if (word == token_str) {
669+ logit_bias_tokens_start.push_back (id);
670+ break ;
671+ }
672+ }
673+ }
674+
675+ void get_safeguard_token (llama_token & id, std::string token_str, std::string_view safeguard_string) {
676+ if (token_str.front () == ' ' ) {
677+ token_str = token_str.substr (1 );
678+ }
679+
680+ if (token_str.back () == ' ' ) {
681+ token_str.pop_back ();
682+ }
683+
684+ if (token_str == safeguard_string) {
685+ safeguard_token = id;
686+ }
687+ }
688+
689+ void processByVocab (std::string safeguard_string) {
690+ bool has_logit_biases = (params.sparams .logit_bias_strings .size () || params.sparams .logit_bias_strings_ext .size ());
691+ bool has_logit_biases_start = params.sparams .logit_bias_strings_start .size ();
692+
693+ for (llama_token id = 0 ; id < llama_vocab_n_tokens (vocab); id++) {
694+ std::string token_str = common_token_to_piece (ctx, id);
695+
696+ if (has_logit_biases) logit_bias_postfill (id, token_str);
697+ if (has_logit_biases_start) start_bias_tokens_postfill (id, token_str);
698+ if (safeguard_token < 0 ) get_safeguard_token (id, token_str, safeguard_string);
699+ }
700+
701+ printf (" %s: finihsed...\n " , __func__);
621702 }
622703
// Rebuild the human-readable summaries of the active logit biases:
// hard bans (-INFINITY) go to logit_bias_strings_display, weighted entries
// to logit_bias_strings_ext_display, and message-start tokens to
// logit_bias_strings_start_display.
void get_logit_bias_str() {
    logit_bias_strings_display = "";
    logit_bias_strings_ext_display = "";
    logit_bias_strings_start_display = "";

    for (const auto & entry : params.sparams.logit_bias) {
        std::string piece = common_token_to_piece(ctx, entry.token);
        if (entry.bias == -INFINITY) {
            logit_bias_strings_display += std::format("'{}';", piece);
        } else {
            logit_bias_strings_ext_display += std::format("'{}'={:.2f};", piece, entry.bias);
        }
    }

    for (const auto & tkn : logit_bias_tokens_start) {
        logit_bias_strings_start_display += std::format("'{}';", common_token_to_piece(ctx, tkn));
    }
}
634721
635722 void params_postfill () {
@@ -1055,23 +1142,23 @@ class chat
10551142 }
10561143
10571144 int load (bool soft = false ){
1058-
1145+
10591146 printf (" Load start \n " );
1060-
1147+
10611148 auto & sparams = params.sparams ;
10621149 // this function is only needed if backends are compiled as dynamic libraries
10631150 // there might be buffer problems for now
10641151 ggml_backend_load_all ();
10651152 printf (" ..............Loaded dynamic backends.(%s)................\n " , __func__);
1066-
1153+
10671154 // if (!soft){
10681155 // int status = 0;
10691156
10701157 // printf("strictLoad \n");
10711158 // status = strictLoad();
10721159 // if (status == 0) return 0;
10731160 // }
1074-
1161+
10751162 if (!soft){
10761163 int status = 0 ;
10771164
@@ -1146,8 +1233,10 @@ class chat
11461233 printf (" %s: llama_n_ctx = %d\n " , __func__, n_ctx);
11471234
11481235 // processing restricted words into logit_bias
1149- sparams_postfill ();
1236+ // sparams_postfill();
11501237 // sparams_postfill_ext();
1238+ // get_safeguard_token("Title");
1239+ processByVocab (" Title" );
11511240
11521241
11531242 smpl = common_sampler_init (model, sparams);
@@ -1502,7 +1591,7 @@ class chat
15021591 }
15031592
15041593 // main generation, includes adding antiprompt at the end
1505- int sampleTknIntoEmbd (bool emptyMessage = false ) {
1594+ int sampleTknIntoEmbd (bool emptyMessage = false , bool shortMessage = false ) {
15061595 if (debug) printf (" -ae" );
15071596
15081597 // optionally save the session on first sample (for faster prompt loading next time)
@@ -1516,13 +1605,52 @@ class chat
15161605 llama_token id = common_sampler_sample (smpl, ctx, -1 );
15171606
15181607 // try to sample a different token to avoid empty messages
1519- while (emptyMessage == true && llama_token_is_eog (vocab, id)) {
1608+ int attempts = 1000 ; // safeguard
1609+ while (emptyMessage == true && llama_token_is_eog (vocab, id) && attempts > 0 ) {
15201610 ++c_empty_msgs;
1521- common_sampler_reset (smpl);
1522- id = common_sampler_sample (smpl, ctx, -1 );
1611+ --attempts;
1612+ // common_sampler_reset(smpl);
1613+ id = common_sampler_shift (smpl, ctx, -1 , id);
1614+ }
1615+
1616+ for (auto l_b : params.sparams .logit_bias ) {
1617+ if (l_b.bias < -99 && id == l_b.token ) {
1618+ std::string c_bias_tkn_string = common_token_to_piece (ctx, id);
1619+ writeTextFile (" logit_biasing.txt" , std::format (" Restricted: '{}';" , c_bias_tkn_string));
1620+
1621+ id = common_sampler_shift (smpl, ctx, -1 , id);
1622+
1623+ c_bias_tkn_string = common_token_to_piece (ctx, id);
1624+ writeTextFile (" logit_biasing.txt" , std::format (" replaced with: '{}'\n " , c_bias_tkn_string));
1625+ }
15231626 }
15241627
1525- // for (auto biased_logit : params.sparams.logit_bias) {
1628+ if (shortMessage) {
1629+ if (llama_token_is_eog (vocab, id)) {
1630+ id = safeguard_token;
1631+ } else {
1632+ int checks = 0 ;
1633+ while (checks < logit_bias_tokens_start.size ()) {
1634+ for (auto l : logit_bias_tokens_start) {
1635+ ++checks;
1636+ if (id == l) {
1637+ checks = 0 ;
1638+ std::string c_restricted_tkn_string = common_token_to_piece (ctx, id);
1639+ writeTextFile (" logit_biasing.txt" , std::format (" Found: '{}';" , c_restricted_tkn_string));
1640+
1641+ id = common_sampler_shift (smpl, ctx, -1 , id);
1642+
1643+ c_restricted_tkn_string = common_token_to_piece (ctx, id);
1644+ writeTextFile (" logit_biasing.txt" , std::format (" replaced with: '{}'\n " , c_restricted_tkn_string));
1645+
1646+ break ;
1647+ }
1648+ }
1649+ }
1650+ }
1651+ }
1652+
1653+ for (auto biased_logit : params.sparams .logit_bias ) {
15261654 // if (biased_logit.bias < 0 && biased_logit.token == id) {
15271655 // int attempts = 1000; // safeguard
15281656 // while (biased_logit.token == id && attempts > 0) {
@@ -1531,14 +1659,14 @@ class chat
15311659 // --attempts;
15321660 // }
15331661
1534- // if (biased_logit.token == id) {
1535- // ++c_restricted_tkns;
1662+ if (biased_logit.token == id) {
1663+ ++c_restricted_tkns;
15361664 // std::string c_restricted_tkn_string = common_token_to_piece(ctx, id);
15371665 // writeTextFile("logit_biasing.txt", std::format("+{}\n", c_restricted_tkn_string));
15381666 // break;
1539- // }
1667+ }
15401668 // }
1541- // }
1669+ }
15421670
15431671 // accept the result
15441672 common_sampler_accept (smpl, id, /* apply_grammar= */ true );
@@ -2016,7 +2144,7 @@ class chat
20162144// this is an attempt to strictly separate all input-based preparations
20172145// however, it assumes conditions (see in getTokenOld())
20182146// fromInpToEmbd() and capture_states() should be done elsewhere
2019- std::string getBit (bool emptyMessage = false ) { // 1 2 3 4
2147+ std::string getBit (bool emptyMessage = false , bool shortMessage = false ) { // 1 2 3 4
20202148 // std::cout << " ** " << std::endl;
20212149 // log_down(std::format("processEmb: {} vs {}\n", embd_inp.size(), n_consumed), params.seed);
20222150
@@ -2026,19 +2154,19 @@ class chat
20262154 return txt_vocab_eos;
20272155 }
20282156
2029- if (!is_interacting) sampleTknIntoEmbd (emptyMessage); // 2
2157+ if (!is_interacting) sampleTknIntoEmbd (emptyMessage, shortMessage ); // 2
20302158
20312159 return getTknFromEmbd ();
20322160 }
20332161
20342162 // token by token generation and pushing
2035- std::string cycleStringsOnly (bool stream = false , bool emptyMessage = false ) {
2163+ std::string cycleStringsOnly (bool emptyMessage = false , bool shortMessage = false ) {
20362164
20372165 dynamicParamsPrepare ();
20382166 // process_prompt(false); // do not forget to include it elsewhere after loading the model
20392167 // inputOnly(input); // MOVED
20402168
2041- std::string bit = getBit (emptyMessage);
2169+ std::string bit = getBit (emptyMessage, shortMessage );
20422170
20432171 if ((int ) std::size (embd_inp) <= n_consumed) {
20442172 if (debug) printf (" -cso" );
0 commit comments