Skip to content

Commit 62e33d0

Browse files
committed
added support for seeded tts voices
1 parent b3de159 commit 62e33d0

File tree

2 files changed

+211
-44
lines changed

2 files changed

+211
-44
lines changed

koboldcpp.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,16 @@ def bring_terminal_to_foreground():
622622
ctypes.windll.user32.ShowWindow(ctypes.windll.kernel32.GetConsoleWindow(), 9)
623623
ctypes.windll.user32.SetForegroundWindow(ctypes.windll.kernel32.GetConsoleWindow())
624624

625+
def simple_lcg_hash(input_string):
    """Deterministically map any string to an integer in [10000, 99998].

    Runs one LCG (linear congruential generator) step per character so the
    same string always yields the same pseudo-random-looking value, suitable
    for turning a voice name into a stable TTS speaker seed.
    """
    multiplier = 1664525      # classic Numerical Recipes LCG multiplier
    increment = 1013904223    # classic Numerical Recipes LCG increment
    modulus = 89999           # keeps the accumulator in [0, 89998]
    acc = 25343               # fixed starting state (empty string -> 35343)
    for ch in input_string:
        acc = (multiplier * acc + ord(ch) + increment) % modulus
    # Shift into the five-digit range expected by the seed consumer.
    return acc + 10000
634+
625635
def string_has_overlap(str_a, str_b, maxcheck):
626636
max_overlap = min(maxcheck, len(str_a), len(str_b))
627637
for i in range(1, max_overlap + 1):
@@ -1331,11 +1341,13 @@ def tts_load_model(ttc_model_filename,cts_model_filename):
13311341
def tts_generate(genparams):
13321342
global args
13331343
is_quiet = True if (args.quiet or args.debugmode == -1) else False
1334-
prompt = genparams.get("input", "")
1344+
prompt = genparams.get("input", genparams.get("text", ""))
13351345
prompt = prompt.strip()
1346+
voicestr = genparams.get("voice", genparams.get("speaker_wav", ""))
1347+
voice = simple_lcg_hash(voicestr) if voicestr else 1
13361348
inputs = tts_generation_inputs()
13371349
inputs.prompt = prompt.encode("UTF-8")
1338-
inputs.speaker_seed = 0
1350+
inputs.speaker_seed = voice
13391351
inputs.audio_seed = 0
13401352
inputs.quiet = is_quiet
13411353
ret = handle.tts_generate(inputs)
@@ -2296,6 +2308,9 @@ def do_GET(self):
22962308
elif self.path.endswith('/sdapi/v1/upscalers'):
22972309
response_body = (json.dumps([]).encode())
22982310

2311+
elif self.path.endswith(('/speakers_list')): #xtts compatible
2312+
response_body = (json.dumps(["kobo","bean","corn","spicy","lime","fire","metal","potato"]).encode()) #some random voices for them to enjoy
2313+
22992314
elif self.path.endswith(('/api/tags')): #ollama compatible
23002315
response_body = (json.dumps({"models":[{"name":"koboldcpp","model":friendlymodelname,"modified_at":"2024-07-19T15:26:55.6122841+08:00","size":394998579,"digest":"b5dc5e784f2a3ee1582373093acf69a2f4e2ac1710b253a001712b86a61f88bb","details":{"parent_model":"","format":"gguf","family":"koboldcpp","families":["koboldcpp"],"parameter_size":"128M","quantization_level":"Q4_0"}}]}).encode())
23012316

@@ -2671,7 +2686,7 @@ def do_POST(self):
26712686
if self.path.endswith('/api/extra/transcribe') or self.path.endswith('/v1/audio/transcriptions'):
26722687
is_transcribe = True
26732688

2674-
if self.path.endswith('/api/extra/tts') or self.path.endswith('/v1/audio/speech'):
2689+
if self.path.endswith('/api/extra/tts') or self.path.endswith('/v1/audio/speech') or self.path.endswith('/tts_to_audio'):
26752690
is_tts = True
26762691

26772692
if is_imggen or is_transcribe or is_tts or api_format > 0:

otherarch/tts_adapter.cpp

Lines changed: 193 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -369,17 +369,56 @@ static std::vector<llama_token> prepare_guide_tokens(const llama_model * model,
369369

370370
// Add the last part
371371
std::string current_word = str.substr(start);
372-
auto tmp = common_tokenize(model, current_word, false, true);
373-
result.push_back(tmp[0]);
372+
if(current_word!="")
373+
{
374+
auto tmp = common_tokenize(model, current_word, false, true);
375+
if(tmp.size()>0){
376+
result.push_back(tmp[0]);
377+
}
378+
}
374379
return result;
375380
}
376381

382+
std::string trim_words(const std::string& input, const std::string& separator, size_t maxWords) {
    // Purpose: split `input` on `separator`, drop empty fragments, keep at
    // most `maxWords` tokens, and re-join them with the same separator.
    // NOTE(review): this keeps the FIRST maxWords tokens; a call site
    // describes it as keeping the "last 300 words" — confirm intended end.

    // Tokenize, skipping zero-length fragments (e.g. doubled separators).
    std::vector<std::string> tokens;
    size_t pos = 0;
    while (true) {
        const size_t hit = input.find(separator, pos);
        if (hit == std::string::npos) {
            break;
        }
        if (hit > pos) {
            tokens.push_back(input.substr(pos, hit - pos));
        }
        pos = hit + separator.length();
    }
    if (pos < input.length()) {
        tokens.push_back(input.substr(pos)); // trailing fragment, if any
    }

    // Cap the token count.
    if (tokens.size() > maxWords) {
        tokens.resize(maxWords);
    }

    // Re-join with the separator between tokens.
    std::string joined;
    for (size_t i = 0; i < tokens.size(); ++i) {
        if (i != 0) {
            joined += separator;
        }
        joined += tokens[i];
    }
    return joined;
}
413+
377414
static llama_context * ttc_ctx = nullptr; //text to codes ctx
378415
static llama_context * cts_ctx = nullptr; //codes to speech
379416

380417
static int ttsdebugmode = 0;
381418
static std::string ttsplatformenv, ttsdeviceenv, ttsvulkandeviceenv;
382419
static std::string last_generated_audio = "";
420+
static std::vector<llama_token> last_speaker_codes; //will store cached speaker
421+
static int last_speaker_seed = -999;
383422

384423
bool ttstype_load_model(const tts_load_model_inputs inputs)
385424
{
@@ -484,14 +523,11 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
484523
const llama_model * model_cts = &(cts_ctx->model);
485524
const int ttc_n_vocab = llama_n_vocab(model_ttc);
486525
std::string prompt = inputs.prompt;
487-
488-
if(!inputs.quiet)
489-
{
490-
printf("\nTTS Generating... ");
491-
}
526+
const std::string sampletext = "but<|text_sep|>that<|text_sep|>is<|text_sep|>what<|text_sep|>it<|text_sep|>is";
492527

493528
// process prompt and generate voice codes
494-
529+
llama_kv_cache_clear(ttc_ctx);
530+
llama_kv_cache_clear(cts_ctx);
495531
std::vector<llama_token> prompt_inp;
496532
prompt_init(prompt_inp, model_ttc);
497533
prompt_add(prompt_inp, model_ttc, "<|text_start|>", false, true);
@@ -501,39 +537,38 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
501537
if (speaker_seed <= 0 || speaker_seed==0xFFFFFFFF)
502538
{
503539
speaker_seed = (((uint32_t)time(NULL)) % 1000000u);
504-
if(ttsdebugmode==1)
505-
{
506-
printf("\nUsing Speaker Seed: %d", speaker_seed);
507-
}
508540
}
509541
if (audio_seed <= 0 || audio_seed==0xFFFFFFFF)
510542
{
511543
audio_seed = (((uint32_t)time(NULL)) % 1000000u);
512-
if(ttsdebugmode==1)
513-
{
514-
printf("\nUsing Audio Seed: %d", audio_seed);
515-
}
544+
}
545+
if(ttsdebugmode==1)
546+
{
547+
printf("\nUsing Speaker Seed: %d", speaker_seed);
548+
printf("\nUsing Audio Seed: %d", audio_seed);
516549
}
517550

518551
std::mt19937 tts_rng(audio_seed);
519552
std::mt19937 speaker_rng(speaker_seed);
520553

521-
//add the speaker based on the seed
522-
if(speaker_seed>0)
523-
{
524-
std::string sampletext = "but<|text_sep|>that<|text_sep|>is<|text_sep|>what<|text_sep|>it<|text_sep|>is<|text_sep|>";
525-
}
554+
int n_decode = 0;
555+
int n_predict = 2048; //will be updated later
556+
bool next_token_uses_guide_token = true;
526557

527558
// convert the input text into the necessary format expected by OuteTTS
528559
std::string prompt_clean = process_text(prompt);
529560

561+
//further clean it by keeping only the last 300 words
562+
prompt_clean = trim_words(prompt_clean,"<|text_sep|>",300);
563+
530564
if(prompt_clean.size()==0)
531565
{
532566
//no input
533567
if(!inputs.quiet)
534568
{
535569
printf("\nTTS sent empty input.\n");
536-
output.data = "";
570+
last_generated_audio = "";
571+
output.data = last_generated_audio.c_str();
537572
output.status = 1;
538573
return output;
539574
}
@@ -544,19 +579,130 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
544579
printf("\nInput: %s\n", prompt_clean.c_str());
545580
}
546581

582+
//2 passes. first pass, we generate the speaker voice if required, then cache it for reuse
583+
//second pass, we use the speaker snippet to align output voice to match the desired speaker
584+
if(speaker_seed>0) //first pass
585+
{
586+
//if we have a cached speaker, reuse it
587+
if(last_speaker_seed==speaker_seed && !last_speaker_codes.empty())
588+
{
589+
//able to proceed, do nothing
590+
if(!inputs.quiet && ttsdebugmode==1)
591+
{
592+
printf("\nReuse speaker ID=%d (%d tokens)...", last_speaker_seed, last_speaker_codes.size());
593+
}
594+
} else {
595+
//generate the voice texture of our new speaker
596+
last_speaker_codes.clear();
597+
guide_tokens = prepare_guide_tokens(model_ttc,sampletext);
598+
prompt_add(prompt_inp, model_ttc, sampletext, false, true);
599+
prompt_add(prompt_inp, model_ttc, "<|text_end|>\n<|audio_start|>\n", false, true);
600+
if(!inputs.quiet && ttsdebugmode==1)
601+
{
602+
printf("\nPrepare new speaker (%d input tokens)...", prompt_inp.size());
603+
}
604+
kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, true);
605+
auto evalok = (llama_decode(ttc_ctx, tts_batch.batch)==0);
606+
if (!evalok) {
607+
printf("\nError: TTS prompt batch processing failed\n");
608+
output.data = "";
609+
output.status = 0;
610+
return output;
611+
}
612+
613+
while (n_decode <= n_predict)
614+
{
615+
float * logits = llama_get_logits(ttc_ctx);
616+
617+
//use creative settings to generate speakers
618+
const int topk = 20;
619+
const float temp = 1.2f;
620+
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,speaker_rng);
621+
622+
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
623+
if(next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
624+
{
625+
if(!guide_tokens.empty())
626+
{
627+
llama_token guide_token = guide_tokens[0];
628+
guide_tokens.erase(guide_tokens.begin());
629+
new_token_id = guide_token; //ensure correct word fragment is used
630+
} else {
631+
n_decode = n_predict; //stop generation
632+
}
633+
}
634+
635+
//this is the token id that always precedes a new word
636+
next_token_uses_guide_token = (new_token_id == 198);
637+
last_speaker_codes.push_back(new_token_id);
638+
639+
// is it an end of generation? -> mark the stream as finished
640+
if (llama_token_is_eog(model_ttc, new_token_id) || n_decode >= n_predict) {
641+
break;
642+
}
643+
644+
n_decode += 1;
645+
std::vector<llama_token> next = {new_token_id};
646+
llama_batch batch = llama_batch_get_one(next.data(), next.size());
647+
648+
// evaluate the current batch with the transformer model
649+
if (llama_decode(ttc_ctx, batch)) {
650+
printf("\nError: TTS code generation failed!\n");
651+
output.data = "";
652+
output.status = 0;
653+
return output;
654+
}
655+
}
656+
657+
//trim everything after final <|code_end|>
658+
auto it = std::find(last_speaker_codes.rbegin(), last_speaker_codes.rend(), 151670);
659+
if (it != last_speaker_codes.rend()) {
660+
// Erase everything after the found <|code_end|> token (151670); the token itself is kept
661+
last_speaker_codes.erase(it.base(), last_speaker_codes.end());
662+
}
663+
last_speaker_seed = speaker_seed;
664+
if(!inputs.quiet && ttsdebugmode==1)
665+
{
666+
printf("\nNew speaker ID=%d created (%d tokens)...", last_speaker_seed, last_speaker_codes.size());
667+
const std::string inp_txt = common_detokenize(ttc_ctx, last_speaker_codes, true);
668+
printf("\n%s\n", inp_txt.c_str());
669+
}
670+
}
671+
guide_tokens.clear();
672+
llama_kv_cache_clear(ttc_ctx);
673+
prompt_init(prompt_inp, model_ttc);
674+
prompt_add(prompt_inp, model_ttc, "<|text_start|>", false, true);
675+
next_token_uses_guide_token = true;
676+
}
677+
678+
//second pass: add the speaker before the actual prompt
547679
guide_tokens = prepare_guide_tokens(model_ttc,prompt_clean);
680+
if(speaker_seed > 0)
681+
{
682+
prompt_clean = sampletext + "<|text_sep|>" + prompt_clean;
683+
}
548684
prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);
549685

550686
if(!inputs.quiet)
551687
{
552-
printf(" (%d input words)...", guide_tokens.size());
688+
printf("\nTTS Generating (%d input tokens)...", prompt_inp.size());
689+
}
690+
691+
prompt_add(prompt_inp, model_ttc, "<|text_end|>\n<|audio_start|>\n", false, true);
692+
693+
if(!last_speaker_codes.empty() && speaker_seed > 0) //apply speaker voice output
694+
{
695+
prompt_add(prompt_inp, last_speaker_codes);
553696
}
554697

555-
prompt_add(prompt_inp, model_ttc, "<|text_end|>\n", false, true);
698+
if(!inputs.quiet && ttsdebugmode==1)
699+
{
700+
printf("\nDUMP TTS PROMPT (%d tokens):\n", prompt_inp.size());
701+
const std::string inp_txt = common_detokenize(ttc_ctx, prompt_inp, true);
702+
printf("\n%s\n", inp_txt.c_str());
703+
}
556704

557705
//create batch with tokens for decoding prompt processing
558-
llama_kv_cache_clear(ttc_ctx);
559-
llama_kv_cache_clear(cts_ctx);
560706
kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, true);
561707

562708
auto evalok = (llama_decode(ttc_ctx, tts_batch.batch)==0);
@@ -568,28 +714,33 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
568714
}
569715

570716
// main loop
571-
int n_decode = 0;
572-
int n_predict = 4096; //max 4096 tokens
573-
574-
bool next_token_uses_guide_token = true;
717+
n_decode = 0;
718+
n_predict = 4096; //max 4096 tokens
575719

576720
while (n_decode <= n_predict)
577721
{
578722
float * logits = llama_get_logits(ttc_ctx);
579723

580-
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,20,1.0,tts_rng);
724+
//use predictable settings to generate voice
725+
const int topk = 4;
726+
const float temp = 0.75f;
727+
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,tts_rng);
581728

582729
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
583-
if(!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
730+
if(next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
584731
{
585-
llama_token guide_token = guide_tokens[0];
586-
guide_tokens.erase(guide_tokens.begin());
587-
new_token_id = guide_token; //ensure correct word fragment is used
732+
if(!guide_tokens.empty())
733+
{
734+
llama_token guide_token = guide_tokens[0];
735+
guide_tokens.erase(guide_tokens.begin());
736+
new_token_id = guide_token; //ensure correct word fragment is used
737+
} else {
738+
n_decode = n_predict; //end generation
739+
}
588740
}
589741

590742
//this is the token id that always precedes a new word
591743
next_token_uses_guide_token = (new_token_id == 198);
592-
593744
codes.push_back(new_token_id);
594745

595746
// is it an end of generation? -> mark the stream as finished
@@ -613,7 +764,6 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
613764
if(!inputs.quiet && ttsdebugmode==1)
614765
{
615766
const std::string inp_txt = common_detokenize(ttc_ctx, codes, true);
616-
617767
printf("\nGenerated %d Codes: '%s'\n",codes.size(), inp_txt.c_str());
618768
}
619769

@@ -628,8 +778,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
628778
if(n_codes<=1)
629779
{
630780
printf("\nWarning: TTS vocoder generated nothing!\n");
631-
output.data = "";
632-
output.status = 0;
781+
last_generated_audio = "";
782+
output.data = last_generated_audio.c_str();
783+
output.status = 1;
633784
return output;
634785
}
635786
kcpp_embd_batch codebatch = kcpp_embd_batch(codes,0,false,true);
@@ -649,8 +800,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
649800

650801
const int n_sr = 24000; // sampling rate
651802

652-
// zero out first 0.05 seconds
653-
for (int i = 0; i < 24000/20; ++i) {
803+
// zero out first 0.25 seconds or 0.05 depending on whether it's seeded
804+
const int cutout = (speaker_seed>0?(24000/4):(24000/20));
805+
for (int i = 0; i < cutout; ++i) {
654806
audio[i] = 0.0f;
655807
}
656808
//add some silence at the end

0 commit comments

Comments
 (0)