@@ -139,6 +139,189 @@ static void prompt_init(std::vector<llama_token> & prompt, const llama_vocab * v
     prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
 }

+static const std::map<int, std::string> ones = {
+    {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"},
+    {5, "five"}, {6, "six"}, {7, "seven"}, {8, "eight"}, {9, "nine"},
+    {10, "ten"}, {11, "eleven"}, {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"},
+    {15, "fifteen"}, {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}
+};
+
+static const std::map<int, std::string> tens = {
+    {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"},
+    {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}
+};
+
+// Convert a number less than 1000 to words,
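+// e.g. 7 -> "seven", 42 -> "forty-two", 342 -> "three hundred forty-two"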
+static std::string convert_less_than_thousand(int num) {
+    std::string result;
+
+    if (num >= 100) {
+        result += ones.at(num / 100) + " hundred ";
+        num %= 100;
+    }
+
+    if (num >= 20) {
+        result += tens.at(num / 10);
+        if (num % 10 > 0) {
+            result += "-" + ones.at(num % 10);
+        }
+    } else if (num > 0) {
+        result += ones.at(num);
+    }
+
+    return result;
+}
+
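+// Convert a decimal number string to English words,
+// e.g. "1250.5" -> "one thousand two hundred fifty point five"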
+static std::string number_to_words(const std::string & number_str) {
+    try {
+        size_t decimal_pos = number_str.find('.');
+        std::string integer_part = number_str.substr(0, decimal_pos);
+
+        int int_number = std::stoi(integer_part);
+        std::string result;
+
+        if (int_number == 0) {
+            result = "zero";
+        } else {
+            if (int_number >= 1000000000) {
+                int billions = int_number / 1000000000;
+                result += convert_less_than_thousand(billions) + " billion ";
+                int_number %= 1000000000;
+            }
+
+            if (int_number >= 1000000) {
+                int millions = int_number / 1000000;
+                result += convert_less_than_thousand(millions) + " million ";
+                int_number %= 1000000;
+            }
+
+            if (int_number >= 1000) {
+                int thousands = int_number / 1000;
+                result += convert_less_than_thousand(thousands) + " thousand ";
+                int_number %= 1000;
+            }
+
+            if (int_number > 0) {
+                result += convert_less_than_thousand(int_number);
+            }
+        }
+
+        // Handle decimal part
+        if (decimal_pos != std::string::npos) {
+            result += " point";
+            std::string decimal_part = number_str.substr(decimal_pos + 1);
+            for (char digit : decimal_part) {
+                result += " " + ones.at(digit - '0');
+            }
+        }
+
+        return result;
+    } catch (const std::exception & e) {
+        // skip the number if conversion fails (e.g. value out of int range)
+        return " ";
+    }
+}
+
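+// Scan the text for numbers and replace each match with its spelled-out form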
+static std::string replace_numbers_with_words(const std::string & input_text) {
+    std::regex number_pattern(R"(\d+(\.\d+)?)");
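+    // matches integers ("42") and decimals ("3.14")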
+    std::string result;
+    auto it  = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern);
+    auto end = std::sregex_iterator();
+
+    size_t last_pos = 0;
+    for (std::sregex_iterator i = it; i != end; ++i) {
+        const std::smatch & match = *i;
+        result.append(input_text, last_pos, match.position() - last_pos);
+        result.append(number_to_words(match.str()));
+        last_pos = match.position() + match.length();
+    }
+    result.append(input_text, last_pos);
+
+    return result;
+}
+
+// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
+static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) {
+
+    // For now I skipped text romanization as I am unsure how to handle
+    // uroman and MeCab implementations in C++;
+    // maybe something like https://github.com/anyascii/anyascii/ could work.
+    // Currently only English is supported in this function.
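+    //
+    // e.g. "Hello, World! 2024" ->
+    //      "hello<|text_sep|>world<|text_sep|>two<|text_sep|>thousand<|text_sep|>twenty<|text_sep|>four"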
+
+    std::string processed_text = replace_numbers_with_words(text);
+
+    std::transform(processed_text.begin(), processed_text.end(),
+                   processed_text.begin(), ::tolower);
+
+    std::regex special_chars(R"([-_/,\.\\])");
+    processed_text = std::regex_replace(processed_text, special_chars, " ");
+
+    std::regex non_alpha(R"([^a-z\s])");
+    processed_text = std::regex_replace(processed_text, non_alpha, "");
+
+    std::regex multiple_spaces(R"(\s+)");
+    processed_text = std::regex_replace(processed_text, multiple_spaces, " ");
+
+    processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), "");
+
+    /*
+        Replace spaces with the separator token, same as in line 365:
+
+        for (auto & c : prompt_user) {
+            if (c == ' ') {
+                prompt_clean += "<|text_sep|>";
+    */
+    std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
+    processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator);
+
+    return processed_text;
+}
+
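+// Split the processed text on the separator token and tokenize each word
+// individually; the resulting tokens can later be used to guide generation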
+static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) {
+    const std::string & delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
+
+    std::vector<llama_token> result;
+    size_t start = 0;
+    size_t end   = str.find(delimiter);
+
+    // first token is always a newline, as it was not previously added
+    result.push_back(llama_token_nl(vocab));
+
+    while (end != std::string::npos) {
+        std::string current_word = str.substr(start, end - start);
+        std::vector<llama_token> tmp(current_word.size());
+        auto n_tmp = llama_tokenize(vocab, current_word.c_str(), current_word.size(), tmp.data(), tmp.size(), false, true);
+        tmp.resize(n_tmp);
+        result.insert(result.end(), tmp.begin(), tmp.end());
+        start = end + delimiter.length();
+        end   = str.find(delimiter, start);
+    }
+
+    // Add the last part
+    std::string current_word = str.substr(start);
+    std::vector<llama_token> tmp(current_word.size());
+    auto n_tmp = llama_tokenize(vocab, current_word.c_str(), current_word.size(), tmp.data(), tmp.size(), false, true);
+    tmp.resize(n_tmp);
+    if (tmp.size() > 0) {
+        result.insert(result.end(), tmp.begin(), tmp.end());
+    }
+    return result;
+}
+
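+// Append a single token to the batch and assign it to the given sequence ids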
+void batch_add(struct llama_batch & batch, llama_token id, llama_pos pos, const std::vector<llama_seq_id> & seq_ids, bool logits) {
+    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
+
+    batch.token[batch.n_tokens] = id;
+    batch.pos[batch.n_tokens] = pos;
+    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+    for (size_t i = 0; i < seq_ids.size(); ++i) {
+        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+    }
+    batch.logits[batch.n_tokens] = logits;
+
+    batch.n_tokens++;
+}
+
 static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
     printf("\n    %s -m model.gguf -mv vocoder.gguf -v en_male_1.json -p \"Hello!\"\n", argv[0]);
@@ -251,6 +434,47 @@ int main(int argc, char ** argv) {
     std::string audio_text = audio_text_from_speaker(speaker, tts_version);
     std::string audio_data = audio_data_from_speaker(speaker, tts_version);

+    std::vector<llama_token> prompt_inp;
+
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    prompt_init(prompt_inp, vocab);
+
+    prompt_add(prompt_inp, vocab, audio_text, false, true);
+
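+    // normalize the user prompt into the model's text format,
+    // e.g. "Hello 42" -> "hello<|text_sep|>forty<|text_sep|>two"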
+    std::string prompt_clean = process_text(prompt, tts_version);
+
+    std::vector<llama_token> guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version);
+
+    prompt_add(prompt_inp, vocab, prompt_clean, false, true);
+
+    prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
+
+    prompt_add(prompt_inp, vocab, audio_data, false, true);
+
+    // create a llama_batch
+    // we use this object to submit token data for decoding
+    llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel);
+
+    std::vector<llama_seq_id> seq_ids(n_parallel, 0);
+    for (int32_t i = 0; i < n_parallel; ++i) {
+        seq_ids[i] = i;
+    }
+
+    // evaluate the initial prompt
+    for (size_t i = 0; i < prompt_inp.size(); ++i) {
+        batch_add(batch, prompt_inp[i], i, seq_ids, false);
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch) != 0) {
+        fprintf(stderr, "%s: llama_decode() failed\n", __func__);
+        return 1;
+    }
+
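+    // wait until the prompt has been fully processed before sampling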
+    llama_synchronize(ctx);
+
     std::vector<llama_token> codes;
-    std::vector<llama_token> guide_tokens;
 }