
Commit c9d1eb3

Added the ability to use guide tokens for OuteTTS, greatly improving TTS recitation accuracy over long input sequences.
1 parent 2739a71 commit c9d1eb3

3 files changed: 51 additions & 1 deletion
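The option is opt-in and is registered for both the TTS example and the server. A hypothetical invocation of the llama-tts example (model file names are placeholders, and the -m/-mv flags are assumed from the example's existing usage; only --tts-use-guide-tokens is introduced by this commit):

./llama-tts -m outetts-0.2-500m-q8_0.gguf -mv wavtokenizer-large-75-f16.gguf -p "a long passage to recite" --tts-use-guide-tokens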

common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -2214,6 +2214,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.vocoder.model = value;
         }
     ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--tts-use-guide-tokens"},
+        "Use guide tokens to improve TTS word recall",
+        [](common_params & params) {
+            params.vocoder.use_guide_tokens = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));

     // model-specific
     add_opt(common_arg(

common/common.h

Lines changed: 2 additions & 0 deletions
@@ -178,6 +178,8 @@ struct common_params_vocoder {

     std::string model     = ""; // model path // NOLINT
     std::string model_url = ""; // model url to download // NOLINT
+
+    bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
 };

 struct common_params {

examples/tts/tts.cpp

Lines changed: 42 additions & 1 deletion
@@ -425,6 +425,29 @@ static void prompt_init(llama_tokens & prompt, const llama_model * model) {
     prompt_add(prompt, model, "<|im_start|>\n", true, true);
 }

+static std::vector<llama_token> prepare_guide_tokens(const llama_model * model, const std::string& str)
+{
+    const std::string& delimiter = "<|text_sep|>";
+
+    std::vector<llama_token> result;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    while (end != std::string::npos) {
+        std::string current_word = str.substr(start, end - start);
+        auto tmp = common_tokenize(model, current_word, false, true);
+        result.push_back(tmp[0]);
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    // Add the last part
+    std::string current_word = str.substr(start);
+    auto tmp = common_tokenize(model, current_word, false, true);
+    result.push_back(tmp[0]);
+    return result;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
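For intuition: prompt_clean (from process_text, below) holds the words separated by <|text_sep|>, and prepare_guide_tokens keeps only the first token of each word, producing one guide token per word. Note that it indexes tmp[0] unconditionally, so it assumes every word tokenizes to at least one token. A standalone sketch of the splitting, with a stub in place of common_tokenize so it compiles without llama.cpp; the stub's character-code "tokens" are purely illustrative:

#include <cstdio>
#include <string>
#include <vector>

// stand-in for common_tokenize(model, word, false, true); real guide tokens
// are the first vocabulary token of each word, not a character code
static std::vector<int> stub_tokenize(const std::string & word) {
    return { word.empty() ? 0 : (int) (unsigned char) word[0] };
}

int main() {
    const std::string text      = "hello<|text_sep|>there<|text_sep|>world";
    const std::string delimiter = "<|text_sep|>";

    std::vector<int> guide;
    size_t start = 0;
    size_t end   = text.find(delimiter);
    while (end != std::string::npos) {
        // keep only the first token of each <|text_sep|>-separated word
        guide.push_back(stub_tokenize(text.substr(start, end - start))[0]);
        start = end + delimiter.length();
        end   = text.find(delimiter, start);
    }
    guide.push_back(stub_tokenize(text.substr(start))[0]); // last word

    for (int t : guide) {
        printf("%d\n", t); // prints 104, 116, 119 ('h', 't', 'w')
    }
    return 0;
}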
@@ -492,6 +515,7 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();

     std::vector<llama_token> codes;
+    std::vector<llama_token> guide_tokens;

     // process prompt and generate voice codes
     {
@@ -506,6 +530,10 @@
         // convert the input text into the necessary format expected by OuteTTS
         {
             std::string prompt_clean = process_text(params.prompt);
+            if(params.vocoder.use_guide_tokens)
+            {
+                guide_tokens = prepare_guide_tokens(model_ttc,prompt_clean);
+            }

             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
@@ -715,6 +743,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
         int n_past   = batch.n_tokens;
         int n_decode = 0;

+        bool next_token_uses_guide_token = true;
+
         while (n_decode <= n_predict) {
             // prepare the next batch
             common_batch_clear(batch);
@@ -726,7 +756,18 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
                     continue;
                 }

-                const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+                llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+                //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+                if(!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
+                {
+                    llama_token guide_token = guide_tokens[0];
+                    guide_tokens.erase(guide_tokens.begin());
+                    new_token_id = guide_token; //ensure correct word fragment is used
+                }
+
+                //this is the token id that always precedes a new word
+                next_token_uses_guide_token = (new_token_id == 198);

                 common_sampler_accept(smpl[i], new_token_id, true);
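To trace the override end to end: whenever the previously emitted token is the newline id 198 (per the comment in the diff, the token that precedes each word in the generated transcript), the next sampled token is swapped for the next queued guide token, so every word starts with the intended fragment while the model completes the rest freely. A minimal sketch with canned sampler output; all ids other than 198 are invented, and the control/EOG checks are omitted for brevity:

#include <cstdio>
#include <deque>

int main() {
    std::deque<int> guide_tokens = {104, 116, 119};       // pretend first-word-fragment tokens
    const int sampled[] = {500, 198, 501, 502, 198, 503}; // pretend sampler output
    bool next_token_uses_guide_token = true;

    for (int sampled_id : sampled) {
        int new_token_id = sampled_id;
        if (!guide_tokens.empty() && next_token_uses_guide_token) {
            new_token_id = guide_tokens.front(); // force the correct word fragment
            guide_tokens.pop_front();
        }
        // a newline (id 198) means the next token starts a new word
        next_token_uses_guide_token = (new_token_id == 198);
        printf("%d ", new_token_id);
    }
    printf("\n"); // full output: 104 198 116 502 198 119
    return 0;
}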