@@ -291,9 +291,12 @@ class chat
291291 std::string logit_bias_strings_display = " " ;
292292 std::string logit_bias_strings_ext_display = " " ;
293293 std::string logit_bias_strings_start_display = " " ;
294+ std::string logit_bias_strings_manual_display = " " ;
294295
295296 std::string last_candidates_logits_display = " " ;
296297
298+ std::string dry_sequence_breakers_display = " " ;
299+
297300 struct llama_perf_context_data ctx_performance_data;
298301
299302 // std::map<std::string,std::string> stats;
@@ -601,6 +604,57 @@ class chat
601604 // std::getline(std::cin, pause);
602605 }
603606
607+ void sparams_postfill2 () {
608+ // std::string space = " ";
609+ if (params.sparams .logit_bias_strings_manual .size ()) {
610+ for (llama_token i = 0 ; i < llama_vocab_n_tokens (vocab); i++) {
611+ std::string token_str = common_token_to_piece (ctx, i);
612+ // cutting spaces since there are "duplicated" tokens with them
613+ if (token_str.front () == ' ' ) {
614+ token_str = token_str.substr (1 );
615+ }
616+
617+ // almost never happens
618+ if (token_str.back () == ' ' ) {
619+ token_str.pop_back ();
620+ }
621+
622+ bool restricted = false ;
623+ float bias = -INFINITY;
624+
625+ if (token_str.length () > 2 ) {
626+ for (auto word : params.sparams .logit_bias_strings_manual ) {
627+ auto token_str_pos = word.find (token_str);
628+
629+ if (token_str_pos == 0 || token_str_pos == (word.size () - 1 )) {
630+ restricted = true ;
631+ break ;
632+ } else if (token_str.find (word) == 0 && (token_str.length () - word.length ()) < 4 ) {
633+ restricted = true ;
634+ break ;
635+ }
636+ }
637+ } else if (token_str.length () > 0 ) {
638+ for (auto word : params.sparams .logit_bias_strings_manual ) {
639+ if (token_str == word) {
640+ restricted = true ;
641+ break ;
642+ }
643+ }
644+ }
645+
646+ if (restricted == true ) {
647+ params.sparams .logit_bias_tokens_manual .push_back (i);
648+ }
649+ }
650+ }
651+
652+ // std::string pause = "";
653+ // std::getline(std::cin, pause);
654+ }
655+
656+
657+
604658 bool logit_bias_check_exact (std::string_view token_str) {
605659 for (auto word : params.sparams .logit_bias_strings_exact ) {
606660 if (token_str == word) return true ;
@@ -757,6 +811,7 @@ class chat
757811 logit_bias_strings_display = " " ;
758812 logit_bias_strings_ext_display = " " ;
759813 logit_bias_strings_start_display = " " ;
814+ logit_bias_strings_manual_display = " " ;
760815
761816 for (auto l : params.sparams .logit_bias ) {
762817 if (l.bias == -INFINITY) {
@@ -769,6 +824,10 @@ class chat
769824 for (auto l : logit_bias_tokens_start) {
770825 logit_bias_strings_start_display += std::format (" '{}';" , common_token_to_piece (ctx, l));
771826 }
827+
828+ for (auto l : params.sparams .logit_bias_tokens_manual ) {
829+ logit_bias_strings_manual_display += std::format (" '{}';" , common_token_to_piece (ctx, l));
830+ }
772831 }
773832
774833 void get_last_candidates_logits_display () {
@@ -779,6 +838,14 @@ class chat
779838 }
780839 }
781840
// Rebuilds the human-readable list of DRY sequence breakers
// ("b1; b2; ...") used for display/debug output.
void get_dry_sequence_breakers_display() {
    dry_sequence_breakers_display.clear();

    // const& — the original range-for copied each breaker string per iteration
    for (const auto & breaker : params.sparams.dry_sequence_breakers) {
        dry_sequence_breakers_display += std::format(" {}; ", breaker);
    }
}
848+
782849 void params_postfill () {
783850 if (params.kv_overrides_pair .size ()) kv_override_prefill ();
784851 common_process_override_tensors (params);
@@ -1296,11 +1363,12 @@ class chat
12961363 printf (" %s: llama_n_ctx = %d\n " , __func__, n_ctx);
12971364
12981365 // processing restricted words into logit_bias
1299- // sparams_postfill ();
1366+ sparams_postfill2 ();
13001367 // sparams_postfill_ext();
13011368 // get_safeguard_token("Title");
13021369 processByVocab (" Title" );
1303-
1370+ get_logit_bias_str ();
1371+ get_dry_sequence_breakers_display ();
13041372
13051373 smpl = common_sampler_init (model, sparams);
13061374 printf (" %s: common_sampler_init\n " , __func__);
@@ -1611,6 +1679,7 @@ class chat
16111679 void check_antiprompt_tkns () {
16121680 // check for reverse prompt using special tokens
16131681 llama_token last_token = common_sampler_last (smpl);
1682+
16141683 for (std::vector<llama_token> ids : antiprompt_ids) {
16151684 if (std::size (ids) == 1 && last_token == ids[0 ]) {
16161685 if (params.interactive ) {
@@ -1623,6 +1692,24 @@ class chat
16231692 }
16241693 }
16251694
// Same check as check_antiprompt_tkns(), but reports whether a single-token
// antiprompt was hit so the caller can branch on the result.
// Side effects on a hit: sets is_antiprompt (and, in interactive mode,
// is_interacting plus the has_antiprompt diagnostic string).
bool check_antiprompt_tkns_bool() {
    // check for reverse prompt using special tokens
    llama_token last_token = common_sampler_last(smpl);

    // const& — the original copied every antiprompt token vector per iteration
    for (const auto & ids : antiprompt_ids) {
        if (std::size(ids) == 1 && last_token == ids[0]) {
            if (params.interactive) {
                is_interacting = true;
                has_antiprompt = std::format("{}: already has antiprompt", __func__);
            }
            is_antiprompt = true;
            return true;
        }
    }

    return false;
}
1712+
16261713 // checking already existing contex
16271714 int checkEmbd (){
16281715 if (debug) printf (" -ce" );
@@ -1678,15 +1765,19 @@ class chat
16781765 id = common_sampler_shift (smpl, ctx, -1 , id);
16791766 }
16801767
1681- for (auto l_b : params.sparams .logit_bias ) {
1682- if (l_b.bias < -99 && id == l_b.token ) {
1683- std::string c_bias_tkn_string = common_token_to_piece (ctx, id);
1684- writeTextFile (" logit_biasing.txt" , std::format (" Restricted: '{}';" , c_bias_tkn_string));
1768+ int checks = 0 ;
1769+ while (checks < params.sparams .logit_bias_tokens_manual .size ()) {
1770+ for (auto tkn : params.sparams .logit_bias_tokens_manual ) {
1771+ ++checks;
1772+ if (id == tkn) {
1773+ std::string c_bias_tkn_string = common_token_to_piece (ctx, id);
1774+ writeTextFile (" logit_biasing.txt" , std::format (" {}: Restricted: '{}';" , params.sparams .seed , c_bias_tkn_string));
16851775
1686- id = common_sampler_shift (smpl, ctx, -1 , id);
1776+ id = common_sampler_shift (smpl, ctx, -1 , id);
16871777
1688- c_bias_tkn_string = common_token_to_piece (ctx, id);
1689- writeTextFile (" logit_biasing.txt" , std::format (" replaced with: '{}'\n " , c_bias_tkn_string));
1778+ c_bias_tkn_string = common_token_to_piece (ctx, id);
1779+ writeTextFile (" logit_biasing.txt" , std::format (" replaced with: '{}'\n " , c_bias_tkn_string));
1780+ }
16901781 }
16911782 }
16921783
@@ -2009,8 +2100,6 @@ class chat
20092100
20102101 if (debug) printf (" Starting initial prompt processing...\n " );
20112102
2012- get_logit_bias_str ();
2013-
20142103
20152104 std::string result;
20162105 // std::cout << " * " << std::endl;
@@ -2075,9 +2164,9 @@ class chat
20752164 const std::string getTknFromEmbd (){
20762165 if (debug) printf (" -gp" );
20772166
2078- for (auto id : embd) {
2079- // return llama_token_to_string(ctx, id);
2080- return common_token_to_piece (ctx, id);
2167+ for (auto id : embd) {
2168+ // return llama_token_to_string(ctx, id);
2169+ return common_token_to_piece (ctx, id);
20812170 }
20822171 }
20832172
@@ -2224,14 +2313,36 @@ class chat
22242313 return getTknFromEmbd ();
22252314 }
22262315
2316+ std::string getMultiBit (int numTkns = 2 , bool emptyMessage = false , bool shortMessage = false ) { // 1 2 3 4
2317+ std::string result = " " ;
2318+
2319+ for (int i = 0 ; i < numTkns; i++) {
2320+ if (checkAndClearEmbd () == 0 ) {
2321+ finished = true ;
2322+ return txt_vocab_eos;
2323+ }
2324+
2325+ if (!is_interacting) sampleTknIntoEmbd (emptyMessage, shortMessage); // 2
2326+
2327+ result += getTknFromEmbd ();
2328+
2329+ if (llama_token_is_eog (vocab, common_sampler_last (smpl))) {
2330+ return result;
2331+ }
2332+ }
2333+
2334+ return result;
2335+ }
2336+
22272337 // token by token generation and pushing
22282338 std::string cycleStringsOnly (bool emptyMessage = false , bool shortMessage = false ) {
22292339
22302340 dynamicParamsPrepare ();
22312341 // process_prompt(false); // do not forget to include it elsewhere after loading the model
22322342 // inputOnly(input); // MOVED
22332343
2234- std::string bit = getBit (emptyMessage, shortMessage);
2344+ // std::string bit = getBit(emptyMessage, shortMessage);
2345+ std::string bit = getMultiBit (2 , emptyMessage, shortMessage);
22352346
22362347 if ((int ) std::size (embd_inp) <= n_consumed) {
22372348 if (debug) printf (" -cso" );
0 commit comments