@@ -425,6 +425,29 @@ static void prompt_init(llama_tokens & prompt, const llama_model * model) {
     prompt_add(prompt, model, "<|im_start|>\n", true, true);
 }

+static std::vector<llama_token> prepare_guide_tokens(const llama_model * model, const std::string& str)
+{
+    const std::string& delimiter = "<|text_sep|>";
+
+    std::vector<llama_token> result;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    while (end != std::string::npos) {
+        std::string current_word = str.substr(start, end - start);
+        auto tmp = common_tokenize(model, current_word, false, true);
+        result.push_back(tmp[0]);
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    // add the last part after the final delimiter
+    std::string current_word = str.substr(start);
+    auto tmp = common_tokenize(model, current_word, false, true);
+    result.push_back(tmp[0]);
+    return result;
+}
+
 int main(int argc, char ** argv) {
     common_params params;

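For reference, prepare_guide_tokens splits the processed prompt, which OuteTTS formats as words joined by <|text_sep|>, and keeps only the first token of each word. Below is a minimal standalone sketch of that splitting behavior, with a hypothetical fake_first_token() standing in for common_tokenize (which needs a loaded model); unlike the helper above, the sketch also skips empty words, which the helper assumes do not occur:

// sketch only: fake_first_token() is a hypothetical stand-in for
// common_tokenize(model, word, false, true)[0]
#include <cstdio>
#include <string>
#include <vector>

static int fake_first_token(const std::string & word) {
    return word.empty() ? -1 : (int) (unsigned char) word[0]; // illustration only
}

int main() {
    const std::string prompt    = "this<|text_sep|>is<|text_sep|>a<|text_sep|>test";
    const std::string delimiter = "<|text_sep|>";

    std::vector<int> guide;

    size_t start = 0;
    size_t end   = prompt.find(delimiter);
    while (end != std::string::npos) {
        const std::string word = prompt.substr(start, end - start);
        if (!word.empty()) {
            guide.push_back(fake_first_token(word)); // one guide token per word
        }
        start = end + delimiter.length();
        end   = prompt.find(delimiter, start);
    }
    const std::string last = prompt.substr(start);
    if (!last.empty()) {
        guide.push_back(fake_first_token(last));
    }

    printf("collected %zu guide tokens\n", guide.size()); // prints 4
    return 0;
}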
@@ -492,6 +515,7 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();

     std::vector<llama_token> codes;
+    std::vector<llama_token> guide_tokens;

     // process prompt and generate voice codes
     {
@@ -506,6 +530,10 @@ int main(int argc, char ** argv) {
         // convert the input text into the necessary format expected by OuteTTS
         {
             std::string prompt_clean = process_text(params.prompt);
+            if (params.vocoder.use_guide_tokens)
+            {
+                guide_tokens = prepare_guide_tokens(model_ttc, prompt_clean);
+            }

             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());

@@ -715,6 +743,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     int n_past   = batch.n_tokens;
     int n_decode = 0;

+    bool next_token_uses_guide_token = true;
+
     while (n_decode <= n_predict) {
         // prepare the next batch
         common_batch_clear(batch);
@@ -726,7 +756,18 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
                 continue;
             }

-            const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+            llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+            // guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+            if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
+            {
+                llama_token guide_token = guide_tokens[0];
+                guide_tokens.erase(guide_tokens.begin());
+                new_token_id = guide_token; // ensure the correct word fragment is used
+            }
+
+            // 198 is the token id that always precedes a new word
+            next_token_uses_guide_token = (new_token_id == 198);

             common_sampler_accept(smpl[i], new_token_id, true);

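Putting the two pieces together: the model samples freely within a word, but the first non-special token sampled after the newline token (id 198) is replaced by the next guide token. A minimal sketch of that gating with hypothetical token ids, using is_special() as a stand-in for the llama_token_is_control / llama_token_is_eog checks:

// sketch only: the sampled ids below are hypothetical, and is_special()
// stands in for the control/eog checks made against the real vocabulary
#include <cstdio>
#include <deque>

using llama_token = int;

static bool is_special(llama_token id) { return id < 0; } // stand-in for control/eog checks

int main() {
    std::deque<llama_token> guide_tokens = { 3011, 374, 264 };              // hypothetical word-initial ids
    const llama_token sampled[] = { 9999, 198, 8888, 50, 198, 7777 };       // hypothetical sampler output

    bool next_token_uses_guide_token = true; // the first token of the stream starts a word
    for (llama_token t : sampled) {
        llama_token out = t;
        if (!guide_tokens.empty() && next_token_uses_guide_token && !is_special(t)) {
            out = guide_tokens.front(); // force the correct word fragment
            guide_tokens.pop_front();
        }
        next_token_uses_guide_token = (out == 198); // 198 precedes each new word
        printf("sampled %5d -> emitted %5d\n", t, out);
    }
    return 0;
}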