Skip to content

Commit 8f23597

Browse files
Update common-mmojo.cpp
Signed-off-by: Brad Hutchings <[email protected]>
1 parent 40abf07 commit 8f23597

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

common/common-mmojo.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,15 @@ void string_replace_all(std::string & s, const std::string & search, const std::
448448
// Check whether `str` ends with `suffix`.
// An empty suffix matches any string (including the empty string).
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
    if (str.size() < suffix.size()) {
        return false;
    }
    // compare only the tail of `str` against the suffix
    const std::string_view tail = str.substr(str.size() - suffix.size());
    return tail == suffix;
}
451+
452+
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
453+
bool has_suffix = string_ends_with(str, suffix);
454+
if (has_suffix) {
455+
str = str.substr(0, str.size() - suffix.size());
456+
}
457+
return has_suffix;
458+
}
459+
451460
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
452461
if (!str.empty() && !stop.empty()) {
453462
const char text_last_char = str.back();
@@ -1019,15 +1028,21 @@ struct common_init_result common_init_from_params(common_params & params) {
10191028
params.sampling.ignore_eos = false;
10201029
}
10211030

1022-
if (params.sampling.ignore_eos) {
1023-
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
1024-
if (llama_vocab_is_eog(vocab, i)) {
1025-
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
1026-
params.sampling.logit_bias.push_back({i, -INFINITY});
1027-
}
1031+
// initialize once
1032+
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
1033+
if (llama_vocab_is_eog(vocab, i)) {
1034+
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
1035+
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
10281036
}
10291037
}
10301038

1039+
if (params.sampling.ignore_eos) {
1040+
// add EOG biases to the active set of logit biases
1041+
params.sampling.logit_bias.insert(
1042+
params.sampling.logit_bias.end(),
1043+
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
1044+
}
1045+
10311046
if (params.sampling.penalty_last_n == -1) {
10321047
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
10331048
params.sampling.penalty_last_n = llama_n_ctx(lctx);
@@ -1171,6 +1186,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
11711186
cparams.no_perf = params.no_perf;
11721187
cparams.op_offload = !params.no_op_offload;
11731188
cparams.swa_full = params.swa_full;
1189+
cparams.kv_unified = params.kv_unified;
11741190

11751191
cparams.type_k = params.cache_type_k;
11761192
cparams.type_v = params.cache_type_v;

0 commit comments

Comments
 (0)