@@ -448,6 +448,15 @@ void string_replace_all(std::string & s, const std::string & search, const std::
// Returns true when `str` ends with `suffix` (an empty suffix always matches).
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
    return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Removes `suffix` from the end of `str` in place, if present.
// Returns true when the suffix was found (and removed), false otherwise.
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
    const bool has_suffix = string_ends_with(str, suffix);
    if (has_suffix) {
        // resize() truncates in place; substr() would allocate and copy a
        // temporary string just to assign it back.
        str.resize(str.size() - suffix.size());
    }
    return has_suffix;
}
459+
451460size_t string_find_partial_stop (const std::string_view & str, const std::string_view & stop) {
452461 if (!str.empty () && !stop.empty ()) {
453462 const char text_last_char = str.back ();
@@ -1019,15 +1028,21 @@ struct common_init_result common_init_from_params(common_params & params) {
10191028 params.sampling .ignore_eos = false ;
10201029 }
10211030
1022- if (params.sampling .ignore_eos ) {
1023- for (llama_token i = 0 ; i < llama_vocab_n_tokens (vocab); i++) {
1024- if (llama_vocab_is_eog (vocab, i)) {
1025- LOG_INF (" %s: added %s logit bias = %f\n " , __func__, common_token_to_piece (lctx, i).c_str (), -INFINITY);
1026- params.sampling .logit_bias .push_back ({i, -INFINITY});
1027- }
1031+ // initialize once
1032+ for (llama_token i = 0 ; i < llama_vocab_n_tokens (vocab); i++) {
1033+ if (llama_vocab_is_eog (vocab, i)) {
1034+ LOG_INF (" %s: added %s logit bias = %f\n " , __func__, common_token_to_piece (lctx, i).c_str (), -INFINITY);
1035+ params.sampling .logit_bias_eog .push_back ({i, -INFINITY});
10281036 }
10291037 }
10301038
1039+ if (params.sampling .ignore_eos ) {
1040+ // add EOG biases to the active set of logit biases
1041+ params.sampling .logit_bias .insert (
1042+ params.sampling .logit_bias .end (),
1043+ params.sampling .logit_bias_eog .begin (), params.sampling .logit_bias_eog .end ());
1044+ }
1045+
10311046 if (params.sampling .penalty_last_n == -1 ) {
10321047 LOG_INF (" %s: setting penalty_last_n to ctx_size = %d\n " , __func__, llama_n_ctx (lctx));
10331048 params.sampling .penalty_last_n = llama_n_ctx (lctx);
@@ -1171,6 +1186,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
11711186 cparams.no_perf = params.no_perf ;
11721187 cparams.op_offload = !params.no_op_offload ;
11731188 cparams.swa_full = params.swa_full ;
1189+ cparams.kv_unified = params.kv_unified ;
11741190
11751191 cparams.type_k = params.cache_type_k ;
11761192 cparams.type_v = params.cache_type_v ;
0 commit comments