@@ -425,6 +425,33 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
     prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
 }
 
+static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
+    const std::string & delimiter = "<|text_sep|>";
+
+    std::vector<llama_token> result;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    // the first token is always a newline, as it was not previously added
+    result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
+
+    while (end != std::string::npos) {
+        std::string current_word = str.substr(start, end - start);
+        auto tmp = common_tokenize(vocab, current_word, false, true);
+        result.push_back(tmp[0]);
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    // add the last part
+    std::string current_word = str.substr(start);
+    auto tmp = common_tokenize(vocab, current_word, false, true);
+    if (tmp.size() > 0) {
+        result.push_back(tmp[0]);
+    }
+    return result;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
 
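For context, `prepare_guide_tokens` is a plain delimiter scan that keeps only the first token of each `<|text_sep|>`-separated word. Below is a minimal standalone sketch of the same splitting logic, with a stubbed `fake_tokenize` in place of `common_tokenize` (which needs a loaded vocab); the stub and its token values are assumptions for illustration only:

```cpp
#include <cstdio>
#include <string>
#include <vector>

// stub standing in for common_tokenize: one fake token id per character
// (assumption for illustration; the real helper uses the model vocab)
static std::vector<int> fake_tokenize(const std::string & word) {
    std::vector<int> toks;
    for (unsigned char c : word) {
        toks.push_back((int) c);
    }
    return toks;
}

int main() {
    const std::string text      = "hello<|text_sep|>world<|text_sep|>again";
    const std::string delimiter = "<|text_sep|>";

    std::vector<int> result;
    size_t start = 0;
    size_t end   = text.find(delimiter);

    while (end != std::string::npos) {
        // keep only the leading token of each word, as prepare_guide_tokens does
        result.push_back(fake_tokenize(text.substr(start, end - start))[0]);
        start = end + delimiter.length();
        end   = text.find(delimiter, start);
    }
    result.push_back(fake_tokenize(text.substr(start))[0]); // last word

    for (int t : result) {
        printf("%d ", t); // prints 104 119 97 -> 'h', 'w', 'a'
    }
    printf("\n");
    return 0;
}
```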
@@ -494,6 +521,7 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();
 
     std::vector<llama_token> codes;
+    std::vector<llama_token> guide_tokens;
 
     // process prompt and generate voice codes
     {
@@ -508,6 +536,9 @@ int main(int argc, char ** argv) {
         // convert the input text into the necessary format expected by OuteTTS
         {
             std::string prompt_clean = process_text(params.prompt);
+            if (params.vocoder.use_guide_tokens) {
+                guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
+            }
 
             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
 
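To make the new call concrete: `process_text` normalizes the prompt into OuteTTS's word-separated form, joining words with `<|text_sep|>`, so that is the shape of the string handed to `prepare_guide_tokens`. A hedged usage sketch (the string literal is illustrative, and the token values depend on the loaded vocab):

```cpp
// illustrative shape of process_text output, assuming a loaded vocab
std::string prompt_clean = "hello<|text_sep|>world";
std::vector<llama_token> guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
// guide_tokens now holds: [ tok("\n"), first-tok("hello"), first-tok("world") ]
```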
@@ -717,6 +748,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     int n_past   = batch.n_tokens;
     int n_decode = 0;
 
+    bool next_token_uses_guide_token = true;
+
     while (n_decode <= n_predict) {
         // prepare the next batch
         common_batch_clear(batch);
@@ -728,7 +761,17 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
                 continue;
             }
 
-            const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+            llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+            // guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+            if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
+                llama_token guide_token = guide_tokens[0];
+                guide_tokens.erase(guide_tokens.begin());
+                new_token_id = guide_token; // ensure the correct word fragment is used
+            }
+
+            // token id 198 (the newline token) always precedes a new word
+            next_token_uses_guide_token = (new_token_id == 198);
 
             common_sampler_accept(smpl[i], new_token_id, true);
 
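The override above is effectively a two-state machine: after every newline token, the next sampled token is replaced by the next pending guide token, pinning each spoken word to the intended text. A minimal sketch of that control flow, with a hypothetical `apply_guide` helper and stubbed `is_control` / `is_eog` flags standing in for the `llama_vocab_is_*` predicates:

```cpp
#include <deque>

// newline token id in the OuteTTS vocab (taken from the check in the patch)
static const int TOKEN_NEWLINE = 198;

// returns the token to accept; `expects_guide` carries state across calls
static int apply_guide(std::deque<int> & guide, bool & expects_guide,
                       int sampled, bool is_control, bool is_eog) {
    if (!guide.empty() && expects_guide && !is_control && !is_eog) {
        sampled = guide.front(); // force the correct word fragment
        guide.pop_front();
    }
    expects_guide = (sampled == TOKEN_NEWLINE);
    return sampled;
}
```

The sketch uses a `std::deque` so the pop-front is O(1); the patch itself erases from the front of a `std::vector`, which is O(n) but harmless at guide-token counts of a few dozen.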