Skip to content

Commit 6b25ae9

Browse files
committed
first sync
1 parent 0b864c4 commit 6b25ae9

File tree

1 file changed

+225
-1
lines changed

1 file changed

+225
-1
lines changed

examples/simple-tts/simple-tts.cpp

Lines changed: 225 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,189 @@ static void prompt_init(std::vector<llama_token> & prompt, const llama_vocab * v
139139
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
140140
}
141141

142+
// Lookup tables for spelling numbers as English words.
// `ones` covers 0-19 (the irregular names); `tens` covers the multiples
// of ten from twenty upward. Both are read via .at() by the helpers below.
static const std::map<int, std::string> ones = {
    {0, "zero"},     {1, "one"},       {2, "two"},      {3, "three"},
    {4, "four"},     {5, "five"},      {6, "six"},      {7, "seven"},
    {8, "eight"},    {9, "nine"},      {10, "ten"},     {11, "eleven"},
    {12, "twelve"},  {13, "thirteen"}, {14, "fourteen"},{15, "fifteen"},
    {16, "sixteen"}, {17, "seventeen"},{18, "eighteen"},{19, "nineteen"}
};

static const std::map<int, std::string> tens = {
    {2, "twenty"}, {3, "thirty"},  {4, "forty"},  {5, "fifty"},
    {6, "sixty"},  {7, "seventy"}, {8, "eighty"}, {9, "ninety"}
};

// Spell out a number in [0, 999] as English words.
// Returns an empty string for 0 (the caller is responsible for "zero"),
// and leaves a trailing space after "hundred" — downstream text
// normalization collapses the extra whitespace.
static std::string convert_less_than_thousand(int num) {
    std::string words;

    const int hundreds = num / 100;
    if (hundreds > 0) {
        words += ones.at(hundreds) + " hundred ";
        num -= hundreds * 100;
    }

    if (num == 0) {
        return words;
    }

    // 1-19 have irregular names; 20+ are built from tens (+ optional unit).
    if (num < 20) {
        words += ones.at(num);
        return words;
    }

    words += tens.at(num / 10);
    const int unit = num % 10;
    if (unit != 0) {
        words += "-" + ones.at(unit);
    }
    return words;
}
174+
175+
static std::string number_to_words(const std::string & number_str) {
176+
try {
177+
size_t decimal_pos = number_str.find('.');
178+
std::string integer_part = number_str.substr(0, decimal_pos);
179+
180+
int int_number = std::stoi(integer_part);
181+
std::string result;
182+
183+
if (int_number == 0) {
184+
result = "zero";
185+
} else {
186+
if (int_number >= 1000000000) {
187+
int billions = int_number / 1000000000;
188+
result += convert_less_than_thousand(billions) + " billion ";
189+
int_number %= 1000000000;
190+
}
191+
192+
if (int_number >= 1000000) {
193+
int millions = int_number / 1000000;
194+
result += convert_less_than_thousand(millions) + " million ";
195+
int_number %= 1000000;
196+
}
197+
198+
if (int_number >= 1000) {
199+
int thousands = int_number / 1000;
200+
result += convert_less_than_thousand(thousands) + " thousand ";
201+
int_number %= 1000;
202+
}
203+
204+
if (int_number > 0) {
205+
result += convert_less_than_thousand(int_number);
206+
}
207+
}
208+
209+
// Handle decimal part
210+
if (decimal_pos != std::string::npos) {
211+
result += " point";
212+
std::string decimal_part = number_str.substr(decimal_pos + 1);
213+
for (char digit : decimal_part) {
214+
result += " " + ones.at(digit - '0');
215+
}
216+
}
217+
218+
return result;
219+
} catch (const std::exception& e) {
220+
// Skip if fails
221+
return " ";
222+
}
223+
}
224+
225+
static std::string replace_numbers_with_words(const std::string & input_text) {
226+
std::regex number_pattern(R"(\d+(\.\d+)?)");
227+
std::string result;
228+
auto it = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern);
229+
auto end = std::sregex_iterator();
230+
231+
size_t last_pos = 0;
232+
for (std::sregex_iterator i = it; i != end; ++i) {
233+
const std::smatch& match = *i;
234+
result.append(input_text, last_pos, match.position() - last_pos);
235+
result.append(number_to_words(match.str()));
236+
last_pos = match.position() + match.length();
237+
}
238+
result.append(input_text, last_pos);
239+
240+
return result;
241+
}
242+
243+
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
244+
static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) {
245+
246+
// For now I skipped text romanization as I am unsure how to handle
247+
// uroman and MeCab implementations in C++
248+
// maybe something like https://github.com/anyascii/anyascii/ could work.
249+
// currently only English would be supported in this function
250+
251+
std::string processed_text = replace_numbers_with_words(text);
252+
253+
std::transform(processed_text.begin(), processed_text.end(),
254+
processed_text.begin(), ::tolower);
255+
256+
std::regex special_chars(R"([-_/,\.\\])");
257+
processed_text = std::regex_replace(processed_text, special_chars, " ");
258+
259+
std::regex non_alpha(R"([^a-z\s])");
260+
processed_text = std::regex_replace(processed_text, non_alpha, "");
261+
262+
std::regex multiple_spaces(R"(\s+)");
263+
processed_text = std::regex_replace(processed_text, multiple_spaces, " ");
264+
265+
processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), "");
266+
267+
/*
268+
Replace spaces with the separator token same as in line 365
269+
270+
for (auto & c : prompt_user) {
271+
if (c == ' ') {
272+
prompt_clean += "<|text_sep|>";
273+
*/
274+
std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
275+
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator);
276+
277+
return processed_text;
278+
}
279+
280+
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) {
281+
const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
282+
283+
std::vector<llama_token> result;
284+
size_t start = 0;
285+
size_t end = str.find(delimiter);
286+
287+
//first token is always a newline, as it was not previously added
288+
result.push_back(llama_token_nl(vocab));
289+
290+
while (end != std::string::npos) {
291+
std::string current_word = str.substr(start, end - start);
292+
std::vector<llama_token> tmp(current_word.size());
293+
auto n_tmp = llama_tokenize(vocab, current_word.c_str(), current_word.size(), tmp.data(), tmp.size(), false, true);
294+
tmp.resize(n_tmp);
295+
result.insert(result.end(), tmp.begin(), tmp.end());
296+
start = end + delimiter.length();
297+
end = str.find(delimiter, start);
298+
}
299+
300+
// Add the last part
301+
std::string current_word = str.substr(start);
302+
std::vector<llama_token> tmp(current_word.size());
303+
auto n_tmp = llama_tokenize(vocab, current_word.c_str(), current_word.size(), tmp.data(), tmp.size(), false, true);
304+
tmp.resize(n_tmp);
305+
if (tmp.size() > 0) {
306+
result.insert(result.end(), tmp.begin(), tmp.end());
307+
}
308+
return result;
309+
}
310+
311+
void batch_add(struct llama_batch & batch, llama_token id,llama_pos pos, const std::vector<llama_seq_id> & seq_ids, bool logits) {
312+
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
313+
314+
batch.token [batch.n_tokens] = id;
315+
batch.pos [batch.n_tokens] = pos;
316+
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
317+
for (size_t i = 0; i < seq_ids.size(); ++i) {
318+
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
319+
}
320+
batch.logits [batch.n_tokens] = logits;
321+
322+
batch.n_tokens++;
323+
}
324+
142325
static void print_usage(int, char ** argv) {
143326
printf("\nexample usage:\n");
144327
printf("\n %s -m model.gguf -mv vocoder.gguf -v en_male_1.json -p \"Hello!\"\n", argv[0]);
@@ -251,6 +434,47 @@ int main(int argc, char ** argv) {
251434
std::string audio_text = audio_text_from_speaker(speaker, tts_version);
252435
std::string audio_data = audio_data_from_speaker(speaker, tts_version);
253436

437+
std::vector<llama_token> prompt_inp;
438+
439+
const llama_vocab * vocab = llama_model_get_vocab(model);
440+
441+
prompt_init(prompt_inp, vocab);
442+
443+
prompt_add(prompt_inp, vocab, audio_text, false, true);
444+
445+
std::string prompt_clean = process_text(prompt, tts_version);
446+
447+
std::vector<llama_token> guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version);
448+
449+
prompt_add(prompt_inp, vocab, prompt_clean, false, true);
450+
451+
prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
452+
453+
prompt_add(prompt_inp, vocab, audio_data, false, true);
454+
455+
// create a llama_batch
456+
// we use this object to submit token data for decoding
457+
llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel);
458+
459+
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
460+
for (int32_t i = 0; i < n_parallel; ++i) {
461+
seq_ids[i] = i;
462+
}
463+
464+
// evaluate the initial prompt
465+
for (size_t i = 0; i < prompt_inp.size(); ++i) {
466+
batch_add(batch, prompt_inp[i], i, seq_ids, false);
467+
}
468+
469+
// llama_decode will output logits only for the last token of the prompt
470+
batch.logits[batch.n_tokens - 1] = true;
471+
472+
if (llama_decode(ctx, batch) != 0) {
473+
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
474+
return 1;
475+
}
476+
477+
llama_synchronize(ctx);
478+
254479
std::vector<llama_token> codes;
255-
std::vector<llama_token> guide_tokens;
256480
}

0 commit comments

Comments
 (0)