Skip to content

Commit fb1274e

Browse files
committed
improved sampling for tts and fixed yet another bug. no patch release for this.
1 parent cca4a93 commit fb1274e

File tree

3 files changed

+54
-12
lines changed

3 files changed

+54
-12
lines changed

otherarch/tts_adapter.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -754,8 +754,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
754754

755755
//use creative settings to generate speakers
756756
const int topk = 20;
757+
const float top_p = 1.0f;
757758
const float temp = 1.2f;
758-
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,speaker_rng);
759+
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,std::vector<int32_t>(),1.0,top_p,topk,temp,speaker_rng);
759760

760761
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
761762
if(next_token_uses_guide_token && !llama_vocab_is_control(ttcvocab, new_token_id) && !llama_vocab_is_eog(ttcvocab, new_token_id))
@@ -876,7 +877,8 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
876877
//use predictable settings to generate voice
877878
const int topk = 4;
878879
const float temp = 0.75f;
879-
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,tts_rng);
880+
const float top_p = 1.0f;
881+
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,std::vector<int32_t>(),1.0,top_p,topk,temp,speaker_rng);
880882

881883
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
882884
if(next_token_uses_guide_token && !llama_vocab_is_control(ttcvocab, new_token_id) && !llama_vocab_is_eog(ttcvocab, new_token_id))
@@ -933,7 +935,7 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
933935
const int n_codes = codes.size();
934936
if(n_codes<=1)
935937
{
936-
printf("\nWarning: TTS vocoder generated nothing!\n");
938+
printf("\nWarning: No Audio Tokens Produced!\n");
937939
last_generated_audio = "";
938940
output.data = last_generated_audio.c_str();
939941
output.status = 1;
@@ -963,12 +965,23 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
963965

964966
//audio = resample_wav(audio,n_sr,t_sr); //resample to 16k
965967

966-
for (int i = 0; i < cutout; ++i) {
967-
audio[i] = 0.0f;
968+
if(audio.size()>cutout+16)
969+
{
970+
for (int i = 0; i < cutout; ++i) {
971+
audio[i] = 0.0f;
972+
}
973+
//add some silence at the end
974+
for (int i = 0; i < cutout; ++i) {
975+
audio.push_back(0.0f);
976+
}
968977
}
969-
//add some silence at the end
970-
for (int i = 0; i < cutout; ++i) {
971-
audio.push_back(0.0f);
978+
else
979+
{
980+
printf("\nWarning: TTS vocoder generated nothing!\n");
981+
last_generated_audio = "";
982+
output.data = last_generated_audio.c_str();
983+
output.status = 1;
984+
return output;
972985
}
973986

974987
last_generated_audio = save_wav16_base64(audio, t_sr);

otherarch/utils.cpp

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,9 +369,9 @@ std::vector<float> resample_wav(const std::vector<float>& input, uint32_t input_
369369
}
370370

371371
//a very rudimentary all in one sampling function which has no dependencies
372-
int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float temp, std::mt19937 & rng)
372+
int32_t kcpp_quick_sample(float * logits, const int n_logits, const std::vector<int32_t> & last_n_tokens, float rep_pen, float top_p, int top_k, float temp, std::mt19937 & rng)
373373
{
374-
if (temp <= 0 || top_k==1) {
374+
if (temp <= 0) {
375375
// select the token with the highest logit directly
376376
float max_logit = logits[0];
377377
int32_t max_id = 0;
@@ -392,8 +392,19 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float t
392392

393393
//temperature sample
394394
const float scale = 1.0f/temp;
395+
396+
//sample rep pen
395397
for (int i = 0; i < n_logits; ++i) {
396-
logits_id.push_back(std::make_pair(logits[i]*scale, i));
398+
if (rep_pen>1.0f && std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
399+
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
400+
if (logits[i] < 0.0f) {
401+
logits_id.push_back(std::make_pair(logits[i]*scale*rep_pen, i));
402+
} else {
403+
logits_id.push_back(std::make_pair(logits[i]*scale/rep_pen, i));
404+
}
405+
} else {
406+
logits_id.push_back(std::make_pair(logits[i]*scale, i));
407+
}
397408
}
398409

399410
//sample top_k
@@ -421,6 +432,24 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float t
421432
p /= sum;
422433
}
423434

435+
//apply top p
436+
if (top_p < 1.0) {
437+
double cumsum = 0.0;
438+
for (int i = 0; i < (int) probs.size(); i++) {
439+
cumsum += probs[i];
440+
if (cumsum >= top_p) {
441+
probs.resize(i + 1);
442+
logits_id.resize(i + 1);
443+
break;
444+
}
445+
}
446+
}
447+
448+
// normalize the probs
449+
for (auto & p : probs) {
450+
p /= sum;
451+
}
452+
424453
std::discrete_distribution<> dist(probs.begin(), probs.end());
425454
int idx = dist(rng);
426455

otherarch/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ std::string kcpp_base64_encode(const std::string &data);
6363
std::string get_timestamp_str();
6464
std::vector<float> resample_wav(const std::vector<float>& input, uint32_t input_rate, uint32_t output_rate);
6565

66-
int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float temp, std::mt19937 & rng);
66+
int32_t kcpp_quick_sample(float * logits, const int n_logits, const std::vector<int32_t> & last_n_tokens, float rep_pen, float top_p, int top_k, float temp, std::mt19937 & rng);
6767

6868
struct kcpp_embd_batch { //duplcated from llava_embd_batch
6969
std::vector<int32_t> pos;

0 commit comments

Comments
 (0)