Commit b548695

added rudimentary support for outetts v0.3 500m and 1b models
1 parent 6390a99 commit b548695

File tree: 1 file changed (+31, -16 lines)

examples/tts/tts.cpp

Lines changed: 31 additions & 16 deletions
@@ -371,7 +371,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
 }
 
 // Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
-static std::string process_text(const std::string & text) {
+static std::string process_text(const std::string & text, bool is_version_0_3) {
 
     // For now I skipped text romanization as I am unsure how to handle
     // uroman and MeCab implementations in C++
@@ -401,7 +401,7 @@ static std::string process_text(const std::string & text) {
         if (c == ' ') {
             prompt_clean += "<|text_sep|>";
     */
-    processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
+    processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), is_version_0_3?"<|space|>":"<|text_sep|>");
 
     return processed_text;
 }
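The only functional change inside process_text is the separator that replaces whitespace: v0.2 vocabularies delimit words with <|text_sep|>, while v0.3 uses <|space|>. A minimal sketch of that step in isolation (the helper name replace_separators and the standalone framing are illustrative, not part of the commit):

#include <regex>
#include <string>

// Illustrative helper (not in tts.cpp): replace every whitespace character in the
// already-cleaned prompt text with the word separator the detected version expects.
static std::string replace_separators(const std::string & cleaned, bool is_version_0_3) {
    const char * sep = is_version_0_3 ? "<|space|>" : "<|text_sep|>"; // v0.3 vs v0.2 separator
    return std::regex_replace(cleaned, std::regex(R"(\s)"), sep);
}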
@@ -425,8 +425,7 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
     prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
 }
 
-static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
-    const std::string& delimiter = "<|text_sep|>";
+static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const std::string& delimiter) {
 
     std::vector<llama_token> result;
     size_t start = 0;
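prepare_guide_tokens previously hard-coded <|text_sep|> as its word delimiter; it now receives the delimiter from the caller so the same splitting loop works for both vocabularies. A self-contained sketch of a parameterized split (split_words is a hypothetical name; the real function additionally tokenizes each extracted word with the llama vocab):

#include <string>
#include <vector>

// Hypothetical stand-in for the splitting loop inside prepare_guide_tokens:
// cut the processed prompt on the version-specific delimiter.
static std::vector<std::string> split_words(const std::string & str, const std::string & delimiter) {
    std::vector<std::string> words;
    size_t start = 0;
    size_t end   = str.find(delimiter);
    while (end != std::string::npos) {
        words.push_back(str.substr(start, end - start));
        start = end + delimiter.size();
        end   = str.find(delimiter, start);
    }
    words.push_back(str.substr(start)); // trailing piece after the last delimiter
    return words;
}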
@@ -523,6 +522,11 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> codes;
     std::vector<llama_token> guide_tokens;
 
+    //determine OuteTTS version and vocab code offset. v0.2 does not have <|space|>, but v0.3 does
+    const bool is_version_0_3 = common_tokenize(vocab,"<|space|>",false,true).size()==1;
+    //determine the offset of the first audio code token
+    const int cts_offset = common_tokenize(vocab,"<|0|>",false,true)[0];
+
     // process prompt and generate voice codes
     {
         LOG_INF("%s: constructing prompt ..\n", __func__);
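The detection relies purely on the loaded vocabulary: if "<|space|>" tokenizes to a single special token the model is treated as v0.3, and the token id of "<|0|>" gives the base id of the audio code range. A sketch of the same probing wrapped into a helper (the struct and function names are illustrative; common_tokenize is the llama.cpp common helper already used in the diff):

// Illustrative wrapper around the probing added in this commit.
struct outetts_info {
    bool is_version_0_3; // true when <|space|> exists as a single special token (v0.3)
    int  cts_offset;     // token id of <|0|>, i.e. the first audio code token
};

static outetts_info detect_outetts_version(const llama_vocab * vocab) {
    outetts_info info;
    // v0.2 vocabularies lack <|space|>, so it splits into several tokens there
    info.is_version_0_3 = common_tokenize(vocab, "<|space|>", false, true).size() == 1;
    // the audio code tokens are assumed to be contiguous, starting at <|0|>
    info.cts_offset = common_tokenize(vocab, "<|0|>", false, true)[0];
    return info;
}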
@@ -531,13 +535,17 @@ int main(int argc, char ** argv) {
 
         prompt_init(prompt_inp, vocab);
 
-        prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
+        if (is_version_0_3) {
+            prompt_add(prompt_inp, vocab, "<|text_start|>the<|space|>overall<|space|>package<|space|>from<|space|>just<|space|>two<|space|>people<|space|>is<|space|>pretty<|space|>remarkable<|space|>sure<|space|>i<|space|>have<|space|>some<|space|>critiques<|space|>about<|space|>some<|space|>of<|space|>the<|space|>gameplay<|space|>aspects<|space|>but<|space|>its<|space|>still<|space|>really<|space|>enjoyable<|space|>and<|space|>it<|space|>looks<|space|>lovely<|space|>", false, true);
+        } else {
+            prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
+        }
 
         // convert the input text into the necessary format expected by OuteTTS
         {
-            std::string prompt_clean = process_text(params.prompt);
+            std::string prompt_clean = process_text(params.prompt, is_version_0_3);
             if (params.vocoder.use_guide_tokens) {
-                guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
+                guide_tokens = prepare_guide_tokens(vocab, prompt_clean, is_version_0_3?"<|space|>":"<|text_sep|>");
             }
 
             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
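Both the hard-coded example prompt and the guide-token delimiter branch on the same flag. One possible tidy-up (not part of the commit) is to compute the separator string once and pass it to both places:

// Hypothetical refactor: pick the word separator once, reuse it everywhere.
const std::string sep = is_version_0_3 ? "<|space|>" : "<|text_sep|>";

std::string prompt_clean = process_text(params.prompt, is_version_0_3);
if (params.vocoder.use_guide_tokens) {
    guide_tokens = prepare_guide_tokens(vocab, prompt_clean, sep);
}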
@@ -549,8 +557,8 @@ int main(int argc, char ** argv) {
 
         // disabled to save time on tokenizing each time
         // TODO: load voices from the json files
-#if 0
-        const std::string voice_data = R"(<|audio_start|>
+#if 1
+        std::string voice_data = R"(<|audio_start|>
 the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
 overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
 package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
@@ -582,12 +590,19 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
 looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
 lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
 
-        auto tmp = common_tokenize(vocab, voice_data, false, true);
-        printf("\n\n");
-        for (int i = 0; i < tmp.size(); ++i) {
-            printf("%d, ", tmp[i]);
+        if (is_version_0_3)
+        {
+            voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_start\|>)"), "");
+            voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
         }
-        printf("\n\n");
+
+        prompt_add(prompt_inp, vocab, voice_data, false, true);
+
+        // printf("\n\n");
+        // for (int i = 0; i < tmp.size(); ++i) {
+        //     printf("%d, ", tmp[i]);
+        // }
+        // printf("\n\n");
 #else
         prompt_add(prompt_inp, llama_tokens {
             151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
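For v0.3 the embedded v0.2 speaker transcript is rewritten before tokenization: every <|code_start|> marker is stripped and every <|code_end|> becomes <|space|>. A standalone illustration of the same two regex replacements applied to the first transcript line (compilable on its own; the printed result follows directly from the rewrites):

#include <cstdio>
#include <regex>
#include <string>

int main() {
    std::string line = "the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>";
    // same two rewrites the commit applies to the whole voice_data string for v0.3
    line = std::regex_replace(line, std::regex(R"(<\|code_start\|>)"), "");
    line = std::regex_replace(line, std::regex(R"(<\|code_end\|>)"), "<|space|>");
    // prints: the<|t_0.08|><|257|><|740|><|636|><|913|><|788|><|1703|><|space|>
    printf("%s\n", line.c_str());
    return 0;
}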
@@ -882,7 +897,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     }
 
     // remove all non-audio tokens (i.e. < 151672 || > 155772)
-    codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end());
+    codes.erase(std::remove_if(codes.begin(), codes.end(), [cts_offset](llama_token t) { return t < cts_offset || t > (cts_offset+4100); }), codes.end());
 
     {
         const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
@@ -891,7 +906,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     }
 
     for (auto & token : codes) {
-        token -= 151672;
+        token -= cts_offset;
     }
 
     const auto t_voc_start = ggml_time_us();
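The magic numbers 151672 and 155772 were only valid for the v0.2 vocabulary; with the probed cts_offset the same filter and the subsequent rebase work for any of the supported models, assuming a contiguous audio code range from <|0|> to <|4100|>. A sketch of both steps together (llama_token replaced by int32_t to keep the example self-contained):

#include <algorithm>
#include <cstdint>
#include <vector>

// Keep only token ids inside the audio code range, then rebase them to start at 0.
static void keep_and_rebase_audio_codes(std::vector<int32_t> & codes, int32_t cts_offset) {
    codes.erase(std::remove_if(codes.begin(), codes.end(),
                    [cts_offset](int32_t t) { return t < cts_offset || t > cts_offset + 4100; }),
                codes.end());
    for (auto & t : codes) {
        t -= cts_offset; // same rebase the diff applies before the vocoder stage
    }
}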
