
Commit f51519e

Message start bias, selected token shift, biases fill rework
* `logit_bias_strings_start` json parameter allows restricting tokens that appear at the beginning of generated messages
* all biases are now processed within one pass through the vocab
* `common_sampler_shift` function allows shifting the selected token for external control
* synced with latest upstream commits
1 parent 81bb71b commit f51519e

23 files changed: +877 −340 lines

base_sampling2/chat_layer.h

Lines changed: 166 additions & 38 deletions
@@ -270,6 +270,7 @@ class chat
     int n_embd_inp_last = 0;
     int c_empty_msgs = 0;
     int c_restricted_tkns = 0;
+    int safeguard_token = -1;
 
     // experimenting with simple dynamic parameters
     float d_temp_min = 1.8;
@@ -286,13 +287,16 @@ class chat
     std::string txt_vocab_eos = "";
     std::string logit_bias_strings_display = "";
     std::string logit_bias_strings_ext_display = "";
+    std::string logit_bias_strings_start_display = "";
 
     struct llama_perf_context_data ctx_performance_data;
 
     //std::map<std::string,std::string> stats;
 
     ring_buffer<llama_token> prev_state = ring_buffer<llama_token>(std::max(32, params.sparams.n_prev));
 
+    std::vector<llama_token> logit_bias_tokens_start;
+
     chat(int argc, char ** argv){
         init(argc, argv);
     }
@@ -551,7 +555,7 @@ class chat
                 if (token_str_pos == 0 || token_str_pos == (word.size() - 1)) {
                     restricted = true;
                     break;
-                } else if (token_str.find(word) == 0 && (token_str.length() - word.length() < 4)) {
+                } else if (token_str.find(word) == 0 && (token_str.length() - word.length()) < 4) {
                     restricted = true;
                     break;
                 }
@@ -592,44 +596,127 @@ class chat
         // std::getline(std::cin, pause);
     }
 
-    void sparams_postfill_ext() {
-        // std::string space = " ";
-        if (params.sparams.logit_bias_strings_ext.size()) {
-            for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-                std::string token_str = common_token_to_piece(ctx, i);
-                // cutting spaces since there are "duplicated" tokens with them
-                if (token_str.front() == ' ') {
-                    token_str = token_str.substr(1);
+    void logit_bias_postfill(llama_token & id, std::string token_str) {
+        // cutting spaces since there are "duplicated" tokens with them
+        if (token_str.front() == ' ') {
+            token_str = token_str.substr(1);
+        }
+
+        // almost never happens
+        if (token_str.back() == ' ') {
+            token_str.pop_back();
+        }
+
+        bool restricted = false;
+        float bias = -INFINITY;
+
+        if (token_str.length() > 2) {
+            for (auto word : params.sparams.logit_bias_strings) {
+                auto token_str_pos = word.find(token_str);
+
+                // if (token_str_pos == 0 || token_str_pos == (word.size() - 1)) {
+                if (token_str_pos == 0) {
+                    restricted = true;
+                    break;
+                } else if (token_str.find(word) == 0 && token_str.length() >= word.length() && (token_str.length() - word.length()) < 4) {
+                    restricted = true;
+                    break;
                 }
+            }
 
-            // almost never happens
-            if (token_str.back() == ' ') {
-                token_str.pop_back();
+            for (auto word_bias_pair : params.sparams.logit_bias_strings_ext) {
+                auto token_str_pos = word_bias_pair.first.find(token_str);
+                // if (token_str_pos == 0 || token_str_pos == (word_bias_pair.first.size() - 1) ) {
+                if (token_str_pos == 0 || token_str_pos == (word_bias_pair.first.size() - 1) ) {
+                    restricted = true;
+                    bias = word_bias_pair.second;
+                    break;
+                }
+            }
+        } else if (token_str.length() > 0) {
+            for (auto word : params.sparams.logit_bias_strings) {
+                if (token_str == word) {
+                    restricted = true;
+                    break;
                 }
+            }
 
-            if (token_str.length() > 1) {
-                for (auto word_bias_pair : params.sparams.logit_bias_strings_ext) {
-                    auto token_str_pos = word_bias_pair.first.find(token_str);
-                    if (token_str_pos == 0 || token_str_pos == (word_bias_pair.first.size() - 1) ) {
-                        params.sparams.logit_bias.push_back({i, word_bias_pair.second});
-                        break;
-                    }
-                }
+            for (auto word_bias_pair : params.sparams.logit_bias_strings_ext) {
+                if (token_str == word_bias_pair.first) {
+                    restricted = true;
+                    bias = word_bias_pair.second;
+                    break;
                 }
             }
         }
+
+        if (restricted == true) {
+            params.sparams.logit_bias.push_back({id, bias});
+        }
+    }
+
+    void start_bias_tokens_postfill(llama_token & id, std::string token_str) {
+        if (token_str.front() == ' ') {
+            token_str = token_str.substr(1);
+        }
+
+        if (token_str.back() == ' ') {
+            token_str.pop_back();
+        }
+
+        for (auto word : params.sparams.logit_bias_strings_start) {
+            if (word == token_str) {
+                logit_bias_tokens_start.push_back(id);
+                break;
+            }
+        }
+    }
+
+    void get_safeguard_token(llama_token & id, std::string token_str, std::string_view safeguard_string) {
+        if (token_str.front() == ' ') {
+            token_str = token_str.substr(1);
+        }
+
+        if (token_str.back() == ' ') {
+            token_str.pop_back();
+        }
+
+        if (token_str == safeguard_string) {
+            safeguard_token = id;
+        }
+    }
+
+    void processByVocab(std::string safeguard_string) {
+        bool has_logit_biases = (params.sparams.logit_bias_strings.size() || params.sparams.logit_bias_strings_ext.size());
+        bool has_logit_biases_start = params.sparams.logit_bias_strings_start.size();
+
+        for (llama_token id = 0; id < llama_vocab_n_tokens(vocab); id++) {
+            std::string token_str = common_token_to_piece(ctx, id);
+
+            if (has_logit_biases) logit_bias_postfill(id, token_str);
+            if (has_logit_biases_start) start_bias_tokens_postfill(id, token_str);
+            if (safeguard_token < 0) get_safeguard_token(id, token_str, safeguard_string);
+        }
+
+        printf("%s: finished...\n", __func__);
     }
 
     void get_logit_bias_str() {
         logit_bias_strings_display = "";
         logit_bias_strings_ext_display = "";
+        logit_bias_strings_start_display = "";
+
         for (auto l : params.sparams.logit_bias) {
             if (l.bias == -INFINITY) {
                 logit_bias_strings_display += std::format(" '{}';", common_token_to_piece(ctx, l.token));
             } else {
                 logit_bias_strings_ext_display += std::format(" '{}'={:.2f};", common_token_to_piece(ctx, l.token), l.bias);
             }
         }
+
+        for (auto l : logit_bias_tokens_start) {
+            logit_bias_strings_start_display += std::format(" '{}';", common_token_to_piece(ctx, l));
+        }
     }
 
     void params_postfill() {
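
Taken together, these helpers let processByVocab() replace the old per-list vocabulary scans: each token string is normalized once (leading/trailing space stripped) and then checked against the plain restriction list, the weighted logit_bias_strings_ext map, the logit_bias_strings_start list, and the safeguard string, all in a single pass. The old entry points sparams_postfill() and sparams_postfill_ext() are retired in the hunk below where processByVocab("Title") is called.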
@@ -1055,23 +1142,23 @@ class chat
     }
 
     int load(bool soft = false){
-
+
         printf("Load start \n");
-
+
         auto & sparams = params.sparams;
         // this function is only needed if backends are compiled as dynamic libraries
         // there might be buffer problems for now
         ggml_backend_load_all();
         printf("..............Loaded dynamic backends.(%s)................\n", __func__);
-
+
         // if (!soft){
         //     int status = 0;
 
         //     printf("strictLoad \n");
         //     status = strictLoad();
         //     if (status == 0) return 0;
         // }
-
+
         if (!soft){
             int status = 0;
 
@@ -1146,8 +1233,10 @@ class chat
         printf("%s: llama_n_ctx = %d\n", __func__, n_ctx);
 
         // processing restricted words into logit_bias
-        sparams_postfill();
+        // sparams_postfill();
         //sparams_postfill_ext();
+        // get_safeguard_token("Title");
+        processByVocab("Title");
 
 
         smpl = common_sampler_init(model, sparams);
@@ -1502,7 +1591,7 @@ class chat
     }
 
     // main generation, includes adding antiprompt at the end
-    int sampleTknIntoEmbd(bool emptyMessage = false) {
+    int sampleTknIntoEmbd(bool emptyMessage = false, bool shortMessage = false) {
         if (debug) printf("-ae");
 
         // optionally save the session on first sample (for faster prompt loading next time)
@@ -1516,13 +1605,52 @@ class chat
         llama_token id = common_sampler_sample(smpl, ctx, -1);
 
         // try to sample a different token to avoid empty messages
-        while (emptyMessage == true && llama_token_is_eog(vocab, id)) {
+        int attempts = 1000; // safeguard
+        while (emptyMessage == true && llama_token_is_eog(vocab, id) && attempts > 0) {
             ++c_empty_msgs;
-            common_sampler_reset(smpl);
-            id = common_sampler_sample(smpl, ctx, -1);
+            --attempts;
+            // common_sampler_reset(smpl);
+            id = common_sampler_shift(smpl, ctx, -1, id);
+        }
+
+        for (auto l_b : params.sparams.logit_bias) {
+            if (l_b.bias < -99 && id == l_b.token) {
+                std::string c_bias_tkn_string = common_token_to_piece(ctx, id);
+                writeTextFile("logit_biasing.txt", std::format("Restricted: '{}';", c_bias_tkn_string));
+
+                id = common_sampler_shift(smpl, ctx, -1, id);
+
+                c_bias_tkn_string = common_token_to_piece(ctx, id);
+                writeTextFile("logit_biasing.txt", std::format(" replaced with: '{}'\n", c_bias_tkn_string));
+            }
         }
 
-        // for (auto biased_logit : params.sparams.logit_bias) {
+        if (shortMessage) {
+            if (llama_token_is_eog(vocab, id)) {
+                id = safeguard_token;
+            } else {
+                int checks = 0;
+                while (checks < logit_bias_tokens_start.size()) {
+                    for (auto l : logit_bias_tokens_start) {
+                        ++checks;
+                        if (id == l) {
+                            checks = 0;
+                            std::string c_restricted_tkn_string = common_token_to_piece(ctx, id);
+                            writeTextFile("logit_biasing.txt", std::format("Found: '{}';", c_restricted_tkn_string));
+
+                            id = common_sampler_shift(smpl, ctx, -1, id);
+
+                            c_restricted_tkn_string = common_token_to_piece(ctx, id);
+                            writeTextFile("logit_biasing.txt", std::format(" replaced with: '{}'\n", c_restricted_tkn_string));
+
+                            break;
+                        }
+                    }
+                }
+            }
+        }
+
+        for (auto biased_logit : params.sparams.logit_bias) {
         // if (biased_logit.bias < 0 && biased_logit.token == id) {
         //     int attempts = 1000; // safeguard
         //     while (biased_logit.token == id && attempts > 0) {
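
In the shortMessage path above, checks is reset to 0 whenever the sampled token is found in logit_bias_tokens_start, so the outer while re-scans the list from scratch after every shift; the loop exits only once a full pass (checks reaching the list size) finds no match. An end-of-generation token sampled at the start of a short message is replaced with safeguard_token instead.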
@@ -1531,14 +1659,14 @@ class chat
         //         --attempts;
         //     }
 
-        // if (biased_logit.token == id) {
-        //     ++c_restricted_tkns;
+            if (biased_logit.token == id) {
+                ++c_restricted_tkns;
             // std::string c_restricted_tkn_string = common_token_to_piece(ctx, id);
             // writeTextFile("logit_biasing.txt", std::format("+{}\n", c_restricted_tkn_string));
             // break;
-        // }
+            }
         // }
-        // }
+        }
 
         // accept the result
         common_sampler_accept(smpl, id, /* apply_grammar= */ true);
@@ -2016,7 +2144,7 @@ class chat
     // this is an attempt to strictly separate all input-based preparations
     // however, it assumes conditions (see in getTokenOld())
     // fromInpToEmbd() and capture_states() should be done elsewhere
-    std::string getBit(bool emptyMessage = false) { // 1 2 3 4
+    std::string getBit(bool emptyMessage = false, bool shortMessage = false) { // 1 2 3 4
         //std::cout << " ** " << std::endl;
         //log_down(std::format("processEmb: {} vs {}\n", embd_inp.size(), n_consumed), params.seed);
 
@@ -2026,19 +2154,19 @@ class chat
             return txt_vocab_eos;
         }
 
-        if (!is_interacting) sampleTknIntoEmbd(emptyMessage); // 2
+        if (!is_interacting) sampleTknIntoEmbd(emptyMessage, shortMessage); // 2
 
         return getTknFromEmbd();
     }
 
     // token by token generation and pushing
-    std::string cycleStringsOnly(bool stream = false, bool emptyMessage = false) {
+    std::string cycleStringsOnly(bool emptyMessage = false, bool shortMessage = false) {
 
         dynamicParamsPrepare();
         //process_prompt(false); // do not forget to include it elsewhere after loading the model
         //inputOnly(input); // MOVED
 
-        std::string bit = getBit(emptyMessage);
+        std::string bit = getBit(emptyMessage, shortMessage);
 
         if ((int) std::size(embd_inp) <= n_consumed) {
             if (debug) printf("-cso");
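
common_sampler_shift itself lives in the sampling code, which is not among the hunks shown here. Judging from the call sites alone, it appears to replace a rejected token with a different candidate instead of resetting the sampler and re-rolling. A minimal standalone sketch of that idea, with a hypothetical helper name (shift_token) and a plain candidate vector standing in for the sampler's internal state:

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

using llama_token = int32_t;

// Illustration only: pick the best-scoring candidate that is not the
// rejected token, falling back to the rejected one if nothing else exists.
static llama_token shift_token(std::vector<std::pair<llama_token, float>> candidates,
                               llama_token rejected) {
    // sort by logit, highest first
    std::sort(candidates.begin(), candidates.end(),
              [](const auto & a, const auto & b) { return a.second > b.second; });
    for (const auto & [token, logit] : candidates) {
        if (token != rejected) {
            return token; // the "shifted" choice
        }
    }
    return rejected; // nothing else to pick
}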

base_sampling2/common.h

Lines changed: 2 additions & 0 deletions
@@ -194,8 +194,10 @@ struct common_params_sampling {
     std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
 
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
+
     std::vector<std::string> logit_bias_strings; // words for logit biases
     std::map<std::string, float> logit_bias_strings_ext; // words for logit biases, but with extra configuration
+    std::vector<std::string> logit_bias_strings_start; // restricted beginnings of messages
 
     // print the parameters into a string
     std::string print() const;

base_sampling2/include/jsonParams.h

Lines changed: 1 addition & 0 deletions
@@ -534,6 +534,7 @@ static void getSamplingParamsFromJson(nlohmann::json& config, common_params& par
     // logit_bias_strings
     if (checkJArr(config, "logit_bias_strings")) params.sparams.logit_bias_strings = config["logit_bias_strings"];
     if (checkJObj(config, "logit_bias_strings_ext")) params.sparams.logit_bias_strings_ext = config["logit_bias_strings_ext"];
+    if (checkJArr(config, "logit_bias_strings_start")) params.sparams.logit_bias_strings_start = config["logit_bias_strings_start"];
 
 }
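
For reference, a config sketch exercising the three parameters parsed above; the words and bias values are made up for illustration:

{
    "logit_bias_strings":       ["as an AI", "I cannot"],
    "logit_bias_strings_ext":   { "however": -2.0, "certainly": -1.0 },
    "logit_bias_strings_start": ["Sorry", "Unfortunately"]
}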

base_sampling2/master/ggml/include/ggml.h

Lines changed: 2 additions & 2 deletions
@@ -528,15 +528,15 @@ extern "C" {
         GGML_UNARY_OP_STEP,
         GGML_UNARY_OP_TANH,
         GGML_UNARY_OP_ELU,
+        GGML_UNARY_OP_RELU,
         GGML_UNARY_OP_SIGMOID,
         GGML_UNARY_OP_GELU,
-        GGML_UNARY_OP_GELU_ERF,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
         GGML_UNARY_OP_HARDSWISH,
         GGML_UNARY_OP_HARDSIGMOID,
         GGML_UNARY_OP_EXP,
-        GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_GELU_ERF,
 
         GGML_UNARY_OP_COUNT,
     };
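
Note that GGML_UNARY_OP_* values are plain C enum constants used to dispatch unary ops, so this reorder changes their numeric values: objects compiled against the old header would misidentify the moved entries, and ggml plus everything built on it must be recompiled together.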

base_sampling2/master/ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 13 additions & 0 deletions
@@ -3484,6 +3484,19 @@ void ggml_cpu_init(void) {
         const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
         GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
+
+#ifdef GGML_USE_OPENMP
+        //if (!getenv("OMP_WAIT_POLICY")) {
+        //    // set the wait policy to active, so that OpenMP threads don't sleep
+        //    putenv("OMP_WAIT_POLICY=active");
+        //}
+
+        if (!getenv("KMP_BLOCKTIME")) {
+            // set the time to wait before sleeping a thread
+            // this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
+            putenv("KMP_BLOCKTIME=200"); // 200ms
+        }
+#endif
     }
 
 #if defined(__ARM_ARCH)
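
KMP_BLOCKTIME is honored by the Intel and LLVM OpenMP runtimes and sets how long, in milliseconds, a worker thread spin-waits before going to sleep. Since the putenv() above only runs when the variable is unset, the value can still be overridden from the shell, e.g. KMP_BLOCKTIME=0 to let threads sleep immediately.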
