@@ -369,17 +369,56 @@ static std::vector<llama_token> prepare_guide_tokens(const llama_model * model,

    // Add the last part
    std::string current_word = str.substr(start);
-    auto tmp = common_tokenize(model, current_word, false, true);
-    result.push_back(tmp[0]);
+    if (current_word != "")
+    {
+        auto tmp = common_tokenize(model, current_word, false, true);
+        if (tmp.size() > 0) {
+            result.push_back(tmp[0]);
+        }
+    }
    return result;
}

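For context on why this guard was added: if the input ends with a separator, the final `substr` yields an empty string, tokenizing it can return zero tokens, and the old `tmp[0]` would then read past the end of an empty vector. A standalone sketch of the failure mode, using a hypothetical `fake_tokenize` stand-in rather than the real llama.cpp tokenizer:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for common_tokenize: returns no tokens for an empty word.
static std::vector<int> fake_tokenize(const std::string & word) {
    return word.empty() ? std::vector<int>{} : std::vector<int>{(int)word.size()};
}

int main() {
    std::string current_word = ""; // e.g. the text ended with "<|text_sep|>"
    auto tmp = fake_tokenize(current_word);
    // Old code: result.push_back(tmp[0]); -> out-of-bounds read on an empty vector.
    // New code guards both the empty word and the empty token list:
    std::vector<int> result;
    if (current_word != "" && tmp.size() > 0) {
        result.push_back(tmp[0]);
    }
    assert(result.empty()); // nothing pushed, no crash
}
```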
+std::string trim_words(const std::string & input, const std::string & separator, size_t maxWords) {
+    // Split the input string by the separator
+    std::vector<std::string> words;
+    size_t start = 0, end;
+    while ((end = input.find(separator, start)) != std::string::npos) {
+        std::string last = input.substr(start, end - start);
+        if (last != "") {
+            words.push_back(last);
+        }
+        start = end + separator.length();
+    }
+    std::string last = input.substr(start);
+    if (last != "")
+    {
+        words.push_back(last); // Add the last word
+    }
+
+    // Ensure no more than maxWords are kept
+    if (words.size() > maxWords) {
+        words.resize(maxWords);
+    }
+
+    // Reconstruct the string with the separator
+    std::ostringstream result;
+    for (size_t i = 0; i < words.size(); ++i) {
+        if (i > 0) result << separator;
+        result << words[i];
+    }
+
+    return result.str();
+}
+
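A quick standalone check of the new helper's behavior. Note that `resize(maxWords)` keeps the *first* `maxWords` words and drops the tail, so truncation happens at the end of the text:

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Copy of trim_words from the diff above, so this sketch compiles on its own.
std::string trim_words(const std::string & input, const std::string & separator, size_t maxWords) {
    std::vector<std::string> words;
    size_t start = 0, end;
    while ((end = input.find(separator, start)) != std::string::npos) {
        std::string last = input.substr(start, end - start);
        if (last != "") { words.push_back(last); }
        start = end + separator.length();
    }
    std::string last = input.substr(start);
    if (last != "") { words.push_back(last); }
    if (words.size() > maxWords) { words.resize(maxWords); }
    std::ostringstream result;
    for (size_t i = 0; i < words.size(); ++i) {
        if (i > 0) result << separator;
        result << words[i];
    }
    return result.str();
}

int main() {
    // Keeps the first 2 words and drops the tail: prints "a<|text_sep|>b"
    std::cout << trim_words("a<|text_sep|>b<|text_sep|>c", "<|text_sep|>", 2) << "\n";
}
```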
static llama_context * ttc_ctx = nullptr; // text to codes ctx
static llama_context * cts_ctx = nullptr; // codes to speech

static int ttsdebugmode = 0;
static std::string ttsplatformenv, ttsdeviceenv, ttsvulkandeviceenv;
static std::string last_generated_audio = "";
+static std::vector<llama_token> last_speaker_codes; // will store cached speaker
+static int last_speaker_seed = -999;

bool ttstype_load_model(const tts_load_model_inputs inputs)
{
@@ -484,14 +523,11 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
    const llama_model * model_cts = &(cts_ctx->model);
    const int ttc_n_vocab = llama_n_vocab(model_ttc);
    std::string prompt = inputs.prompt;
-
-    if (!inputs.quiet)
-    {
-        printf("\nTTS Generating... ");
-    }
+    const std::string sampletext = "but<|text_sep|>that<|text_sep|>is<|text_sep|>what<|text_sep|>it<|text_sep|>is";

    // process prompt and generate voice codes
-
+    llama_kv_cache_clear(ttc_ctx);
+    llama_kv_cache_clear(cts_ctx);
    std::vector<llama_token> prompt_inp;
    prompt_init(prompt_inp, model_ttc);
    prompt_add(prompt_inp, model_ttc, "<|text_start|>", false, true);
@@ -501,39 +537,38 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
    if (speaker_seed <= 0 || speaker_seed == 0xFFFFFFFF)
    {
        speaker_seed = (((uint32_t)time(NULL)) % 1000000u);
-        if (ttsdebugmode == 1)
-        {
-            printf("\nUsing Speaker Seed: %d", speaker_seed);
-        }
    }
    if (audio_seed <= 0 || audio_seed == 0xFFFFFFFF)
    {
        audio_seed = (((uint32_t)time(NULL)) % 1000000u);
-        if (ttsdebugmode == 1)
-        {
-            printf("\nUsing Audio Seed: %d", audio_seed);
-        }
+    }
+    if (ttsdebugmode == 1)
+    {
+        printf("\nUsing Speaker Seed: %d", speaker_seed);
+        printf("\nUsing Audio Seed: %d", audio_seed);
    }

    std::mt19937 tts_rng(audio_seed);
    std::mt19937 speaker_rng(speaker_seed);

-    // add the speaker based on the seed
-    if (speaker_seed > 0)
-    {
-        std::string sampletext = "but<|text_sep|>that<|text_sep|>is<|text_sep|>what<|text_sep|>it<|text_sep|>is<|text_sep|>";
-    }
+    int n_decode = 0;
+    int n_predict = 2048; // will be updated later
+    bool next_token_uses_guide_token = true;

    // convert the input text into the necessary format expected by OuteTTS
    std::string prompt_clean = process_text(prompt);

+    // further clean it by keeping only the first 300 words (trim_words drops the tail)
+    prompt_clean = trim_words(prompt_clean, "<|text_sep|>", 300);
+
    if (prompt_clean.size() == 0)
    {
        // no input
        if (!inputs.quiet)
        {
            printf("\nTTS sent empty input.\n");
-            output.data = "";
+            last_generated_audio = "";
+            output.data = last_generated_audio.c_str();
            output.status = 1;
            return output;
        }
@@ -544,19 +579,130 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
        printf("\nInput: %s\n", prompt_clean.c_str());
    }

+    // 2 passes. first pass, we generate the speaker voice if required, then cache it for reuse
+    // second pass, we use the speaker snippet to align output voice to match the desired speaker
+    if (speaker_seed > 0) // first pass
+    {
+        // if we have a cached speaker, reuse it
+        if (last_speaker_seed == speaker_seed && !last_speaker_codes.empty())
+        {
+            // able to proceed, do nothing
+            if (!inputs.quiet && ttsdebugmode == 1)
+            {
+                printf("\nReuse speaker ID=%d (%d tokens)...", last_speaker_seed, (int)last_speaker_codes.size());
+            }
+        } else {
+            // generate the voice texture of our new speaker
+            last_speaker_codes.clear();
+            guide_tokens = prepare_guide_tokens(model_ttc, sampletext);
+            prompt_add(prompt_inp, model_ttc, sampletext, false, true);
+            prompt_add(prompt_inp, model_ttc, "<|text_end|>\n<|audio_start|>\n", false, true);
+            if (!inputs.quiet && ttsdebugmode == 1)
+            {
+                printf("\nPrepare new speaker (%d input tokens)...", (int)prompt_inp.size());
+            }
+            kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, true);
+            auto evalok = (llama_decode(ttc_ctx, tts_batch.batch) == 0);
+            if (!evalok) {
+                printf("\nError: TTS prompt batch processing failed\n");
+                output.data = "";
+                output.status = 0;
+                return output;
+            }
+
+            while (n_decode <= n_predict)
+            {
+                float * logits = llama_get_logits(ttc_ctx);
+
+                // use creative settings to generate speakers
+                const int topk = 20;
+                const float temp = 1.2f;
+                llama_token new_token_id = kcpp_quick_sample(logits, ttc_n_vocab, topk, temp, speaker_rng);
+
+                // guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+                if (next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
+                {
+                    if (!guide_tokens.empty())
+                    {
+                        llama_token guide_token = guide_tokens[0];
+                        guide_tokens.erase(guide_tokens.begin());
+                        new_token_id = guide_token; // ensure correct word fragment is used
+                    } else {
+                        n_decode = n_predict; // stop generation
+                    }
+                }
+
+                // this is the token id that always precedes a new word
+                next_token_uses_guide_token = (new_token_id == 198);
+                last_speaker_codes.push_back(new_token_id);
+
+                // is it an end of generation? -> mark the stream as finished
+                if (llama_token_is_eog(model_ttc, new_token_id) || n_decode >= n_predict) {
+                    break;
+                }
+
+                n_decode += 1;
+                std::vector<llama_token> next = {new_token_id};
+                llama_batch batch = llama_batch_get_one(next.data(), next.size());
+
+                // evaluate the current batch with the transformer model
+                if (llama_decode(ttc_ctx, batch)) {
+                    printf("\nError: TTS code generation failed!\n");
+                    output.data = "";
+                    output.status = 0;
+                    return output;
+                }
+            }
+
+            // trim everything after the final <|code_end|>
+            auto it = std::find(last_speaker_codes.rbegin(), last_speaker_codes.rend(), 151670);
+            if (it != last_speaker_codes.rend()) {
+                // erase everything after the found <|code_end|>, keeping that token itself
+                last_speaker_codes.erase(it.base(), last_speaker_codes.end());
+            }
+            last_speaker_seed = speaker_seed;
+            if (!inputs.quiet && ttsdebugmode == 1)
+            {
+                printf("\nNew speaker ID=%d created (%d tokens)...", last_speaker_seed, (int)last_speaker_codes.size());
+                const std::string inp_txt = common_detokenize(ttc_ctx, last_speaker_codes, true);
+                printf("\n%s\n", inp_txt.c_str());
+            }
+        }
+        guide_tokens.clear();
+        llama_kv_cache_clear(ttc_ctx);
+        prompt_init(prompt_inp, model_ttc);
+        prompt_add(prompt_inp, model_ttc, "<|text_start|>", false, true);
+        next_token_uses_guide_token = true;
+    }
+
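The reverse-iterator trim above is subtle enough to deserve a note: for a reverse iterator `it` found via `rbegin()/rend()`, `it.base()` is the forward iterator one past the matched element, so `erase(it.base(), end())` deletes everything after the *last* match while keeping the match itself. A minimal standalone check with stand-in ids (7 plays the role of the `<|code_end|>` id 151670):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
    std::vector<int> codes = {1, 2, 7, 3, 7, 4, 5};
    auto it = std::find(codes.rbegin(), codes.rend(), 7); // last occurrence of 7
    if (it != codes.rend()) {
        // it.base() points one past the found element, so the 7 itself is kept
        codes.erase(it.base(), codes.end());
    }
    assert((codes == std::vector<int>{1, 2, 7, 3, 7}));
}
```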
+    // second pass: add the speaker before the actual prompt
    guide_tokens = prepare_guide_tokens(model_ttc, prompt_clean);
+    if (speaker_seed > 0)
+    {
+        prompt_clean = sampletext + "<|text_sep|>" + prompt_clean;
+    }
    prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);

    if (!inputs.quiet)
    {
-        printf(" (%d input words)...", guide_tokens.size());
+        printf("\nTTS Generating (%d input tokens)...", (int)prompt_inp.size());
+    }
+
+    prompt_add(prompt_inp, model_ttc, "<|text_end|>\n<|audio_start|>\n", false, true);
+
+    if (!last_speaker_codes.empty() && speaker_seed > 0) // apply speaker voice output
+    {
+        prompt_add(prompt_inp, last_speaker_codes);
    }

-    prompt_add(prompt_inp, model_ttc, "<|text_end|>\n", false, true);
+    if (!inputs.quiet && ttsdebugmode == 1)
+    {
+        printf("\nDUMP TTS PROMPT (%d tokens):\n", (int)prompt_inp.size());
+        const std::string inp_txt = common_detokenize(ttc_ctx, prompt_inp, true);
+        printf("\n%s\n", inp_txt.c_str());
+    }
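For clarity, the prompt assembled by this second pass has the following overall shape (an illustrative layout written as token text, not exact ids):

```cpp
// Illustrative layout of the final TTC prompt after the second pass:
//
//   <|text_start|>but<|text_sep|>that<|text_sep|>...<|text_sep|>{user words}<|text_end|>
//   <|audio_start|>
//   {cached speaker codes for the sample text, ending at <|code_end|>}
//
// Decoding then continues the audio code stream in the cached speaker's voice,
// while the guide tokens (prepared from the user text only) steer it word by word.
```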

    // create batch with tokens for decoding prompt processing
-    llama_kv_cache_clear(ttc_ctx);
-    llama_kv_cache_clear(cts_ctx);
    kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, true);

    auto evalok = (llama_decode(ttc_ctx, tts_batch.batch) == 0);
@@ -568,28 +714,33 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
    }

    // main loop
-    int n_decode = 0;
-    int n_predict = 4096; // max 4096 tokens
-
-    bool next_token_uses_guide_token = true;
+    n_decode = 0;
+    n_predict = 4096; // max 4096 tokens

    while (n_decode <= n_predict)
    {
        float * logits = llama_get_logits(ttc_ctx);

-        llama_token new_token_id = kcpp_quick_sample(logits, ttc_n_vocab, 20, 1.0, tts_rng);
+        // use predictable settings to generate voice
+        const int topk = 4;
+        const float temp = 0.75f;
+        llama_token new_token_id = kcpp_quick_sample(logits, ttc_n_vocab, topk, temp, tts_rng);

        // guide tokens help prevent hallucinations by forcing the TTS to use the correct word
-        if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
+        if (next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
        {
-            llama_token guide_token = guide_tokens[0];
-            guide_tokens.erase(guide_tokens.begin());
-            new_token_id = guide_token; // ensure correct word fragment is used
+            if (!guide_tokens.empty())
+            {
+                llama_token guide_token = guide_tokens[0];
+                guide_tokens.erase(guide_tokens.begin());
+                new_token_id = guide_token; // ensure correct word fragment is used
+            } else {
+                n_decode = n_predict; // end generation
+            }
        }

        // this is the token id that always precedes a new word
        next_token_uses_guide_token = (new_token_id == 198);
-
        codes.push_back(new_token_id);

        // is it an end of generation? -> mark the stream as finished
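The guide-token mechanism in both loops follows the same pattern: whenever the previously emitted token is the newline id 198, the next sampled token is overridden with the next pre-tokenized word fragment, and generation stops once the guide list runs out. A minimal simulation with stand-in ids and a hard-coded "sampler" (the real code also skips control and end-of-generation tokens via `llama_token_is_*`):

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int NEWLINE_ID = 198;                 // id that precedes a new word
    std::vector<int> guide_tokens = {500, 501}; // pre-tokenized word fragments
    std::vector<int> sampled = {7, 42, 198, 43, 198, 44}; // fake sampler output

    bool next_uses_guide = true; // first word is always guided
    std::vector<int> out;
    for (int tok : sampled) {
        if (next_uses_guide) {
            if (!guide_tokens.empty()) {
                tok = guide_tokens.front();     // force the correct word fragment
                guide_tokens.erase(guide_tokens.begin());
            } else {
                break;                          // guides exhausted -> stop early
            }
        }
        next_uses_guide = (tok == NEWLINE_ID);  // newline => next token is guided
        out.push_back(tok);
    }
    for (int t : out) printf("%d ", t);         // prints: 500 42 198 501 198
    printf("\n");
}
```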
@@ -613,7 +764,6 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
    if (!inputs.quiet && ttsdebugmode == 1)
    {
        const std::string inp_txt = common_detokenize(ttc_ctx, codes, true);
-
        printf("\nGenerated %d Codes: '%s'\n", codes.size(), inp_txt.c_str());
    }

@@ -628,8 +778,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
    if (n_codes <= 1)
    {
        printf("\nWarning: TTS vocoder generated nothing!\n");
-        output.data = "";
-        output.status = 0;
+        last_generated_audio = "";
+        output.data = last_generated_audio.c_str();
+        output.status = 1;
        return output;
    }
    kcpp_embd_batch codebatch = kcpp_embd_batch(codes, 0, false, true);
@@ -649,8 +800,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)

    const int n_sr = 24000; // sampling rate

-    // zero out first 0.05 seconds
-    for (int i = 0; i < 24000/20; ++i) {
+    // zero out the first 0.25 seconds if seeded, otherwise the first 0.05 seconds
+    const int cutout = (speaker_seed > 0 ? (24000/4) : (24000/20));
+    for (int i = 0; i < cutout; ++i) {
        audio[i] = 0.0f;
    }
    // add some silence at the end
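As a sanity check on the cutout arithmetic: at the 24 kHz sample rate, `24000/4` is 6000 samples (0.25 s) for seeded voices and `24000/20` is 1200 samples (0.05 s) otherwise. A standalone verification:

```cpp
#include <cstdio>

int main() {
    const int n_sr = 24000; // sampling rate, as in the diff above
    for (int speaker_seed : {0, 12345}) {
        const int cutout = (speaker_seed > 0 ? (n_sr / 4) : (n_sr / 20));
        // prints: seed=0 -> 1200 samples (0.05 s); seed=12345 -> 6000 samples (0.25 s)
        printf("seed=%d -> cutout=%d samples (%.2f s)\n",
               speaker_seed, cutout, (double)cutout / n_sr);
    }
}
```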