Skip to content

Commit d39b6f9

Browse files
committed
simple-tts
1 parent 6b25ae9 commit d39b6f9

File tree

1 file changed

+299
-0
lines changed

1 file changed

+299
-0
lines changed

examples/simple-tts/simple-tts.cpp

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,196 @@ void batch_add(struct llama_batch & batch, llama_token id,llama_pos pos, const s
322322
batch.n_tokens++;
323323
}
324324

325+
// Write `data` to `fname` as a 16-bit PCM WAV file.
//
//   fname       - path of the file to create
//   data        - float samples; each one is scaled by 32767 and clamped to
//                 [-32768, 32767] before being truncated to int16
//   sample_rate - value stored in the header's sample_rate field
//
// Header fields that are not assigned here keep whatever defaults wav_header
// declares. Failures are reported on stderr; nothing is returned.
static void save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) {
    std::ofstream file(fname, std::ios::binary);
    if (!file) {
        fprintf(stderr, "%s: Failed to open file '%s' for writing\n", __func__, fname.c_str());
        return;
    }

    wav_header header;
    header.sample_rate = sample_rate;
    header.byte_rate   = header.sample_rate * header.num_channels * (header.bits_per_sample / 8);
    header.block_align = header.num_channels * (header.bits_per_sample / 8);
    header.data_size   = data.size() * (header.bits_per_sample / 8);
    // NOTE(review): 36 is presumably the number of header bytes that follow the
    // chunk_size field in a canonical 44-byte WAV header - confirm against wav_header
    header.chunk_size  = 36 + header.data_size;

    file.write(reinterpret_cast<const char*>(&header), sizeof(header));

    // convert everything first so the payload goes out in a single write
    std::vector<int16_t> pcm;
    pcm.reserve(data.size());
    for (const auto & sample : data) {
        pcm.push_back(static_cast<int16_t>(std::clamp(sample * 32767.0, -32768.0, 32767.0)));
    }
    file.write(reinterpret_cast<const char*>(pcm.data()), pcm.size() * sizeof(int16_t));

    file.close();

    // a short write (disk full, I/O error, ...) would otherwise go unnoticed
    if (!file) {
        fprintf(stderr, "%s: Failed to write to file '%s'\n", __func__, fname.c_str());
    }
}
348+
349+
// Fill output[0 .. length) with a Hann window:
//
//   output[i] = 0.5 * (1 - cos(2*pi*i / D))
//
// where D = length when `periodic` is true and D = length - 1 otherwise.
// `output` must have room for `length` floats; nothing is written when
// length <= 0.
static void fill_hann_window(int length, bool periodic, float * output) {
    // D would be 0 here and the formula would evaluate 0/0 (NaN); use the
    // conventional one-sample window instead (same value as numpy.hanning(1))
    if (length == 1 && !periodic) {
        output[0] = 1.0f;
        return;
    }

    const int offset = periodic ? 0 : -1;

    for (int i = 0; i < length; i++) {
        output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
    }
}
358+
359+
// very poor man's FFT (naive O(n^2) DFT)
360+
// Compute the complex exponential e^{i * 2*pi*k/N}.
//
//   real, imag - receive cos(2*pi*k/N) and sin(2*pi*k/N)
//   k, N       - numerator and denominator of the fraction of a full turn
static void twiddle(float * real, float * imag, int k, int N) {
    // The result is periodic in k with period N. Reducing k first keeps the
    // float angle within one turn; callers pass products such as k*m, and a
    // large unreduced angle loses several digits of phase precision in float.
    // (N == 0 is left unreduced to avoid an integer modulo by zero.)
    const int k_red = (N != 0) ? k % N : k;

    const float angle = 2 * M_PI * k_red / N;

    *real = cosf(angle);
    *imag = sinf(angle);
}
365+
366+
static void irfft(int n, const float * inp_cplx, float * out_real) {
367+
int N = n / 2 + 1;
368+
369+
std::vector<float> real_input(N);
370+
std::vector<float> imag_input(N);
371+
for (int i = 0; i < N; ++i) {
372+
real_input[i] = inp_cplx[2 * i];
373+
imag_input[i] = inp_cplx[2 * i + 1];
374+
}
375+
376+
std::vector<float> real_output(n);
377+
std::vector<float> imag_output(n);
378+
379+
for (int k = 0; k < n; ++k) {
380+
real_output[k] = 0.0f;
381+
imag_output[k] = 0.0f;
382+
for (int m = 0; m < N; ++m) {
383+
float twiddle_real;
384+
float twiddle_imag;
385+
386+
twiddle(&twiddle_real, &twiddle_imag, k * m, n);
387+
388+
real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag;
389+
imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real;
390+
}
391+
}
392+
393+
for (int i = 0; i < n; ++i) {
394+
out_real[i] = real_output[i] / N;
395+
}
396+
}
397+
398+
//
399+
// y = torch.nn.functional.fold(
400+
// data, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
401+
// )[:, 0, 0, pad:-pad]
402+
//
403+
// data.shape = torch.Size([1, 1280, 261])
404+
// output_size = 84480
405+
// win_length = 1280
406+
// hop_length = 320
407+
// pad = 480
408+
//
409+
// Overlap-add consecutive windows of `data` into a 1-D signal.
//
//   data   - windows stored back to back, n_win values each
//   n_out  - length of the untrimmed overlap-added signal
//   n_win  - window (kernel) length
//   n_hop  - distance between the starts of consecutive windows
//   n_pad  - number of samples trimmed from each end of the signal
//   output - overwritten with the result, n_out - 2*n_pad values
//            (empty if that is not positive)
//
// Window c is added at positions [c*n_hop - n_pad, c*n_hop - n_pad + n_win);
// contributions that fall outside [0, n_out) are dropped.
static void fold(const std::vector<float> & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float> & output) {
    // assign (rather than resize) so a non-empty `output` does not leak its
    // previous contents into the sums
    output.assign((size_t) (n_out > 0 ? n_out : 0), 0.0f);

    const int64_t n_data = (int64_t) data.size();

    int64_t col_idx = 0;

    // stop as soon as every input value has been consumed instead of walking
    // all n_out candidate columns
    for (int64_t w_col = 0; w_col < n_out && col_idx < n_data; ++w_col) {
        const int64_t start = w_col * n_hop - n_pad;

        for (int64_t j = 0; j < n_win && col_idx < n_data; ++j, ++col_idx) {
            const int64_t w_im = start + j;
            if (w_im >= 0 && w_im < n_out) {
                output[w_im] += data[col_idx];
            }
        }
    }

    // the left trim is already applied by the -n_pad shift above; this drops the tail
    const int64_t n_keep = n_out - 2 * n_pad;
    output.resize((size_t) (n_keep > 0 ? n_keep : 0));
}
432+
433+
// TODO: not optimized at all
434+
static std::vector<float> embd_to_audio(
435+
const float * embd,
436+
const int n_codes,
437+
const int n_embd,
438+
const int n_thread) {
439+
const int n_fft = 1280;
440+
const int n_hop = 320;
441+
const int n_win = 1280;
442+
const int n_pad = (n_win - n_hop)/2;
443+
const int n_out = (n_codes - 1)*n_hop + n_win;
444+
445+
std::vector<float> hann(n_fft);
446+
447+
fill_hann_window(hann.size(), true, hann.data());
448+
449+
int n_spec = n_embd*n_codes;
450+
451+
std::vector<float> E (n_spec);
452+
std::vector<float> S (n_spec);
453+
std::vector<float> ST(n_spec);
454+
455+
for (int l = 0; l < n_codes; ++l) {
456+
for (int k = 0; k < n_embd; ++k) {
457+
E[k*n_codes + l] = embd[l*n_embd + k];
458+
}
459+
}
460+
461+
for (int k = 0; k < n_embd/2; ++k) {
462+
for (int l = 0; l < n_codes; ++l) {
463+
float mag = E[(k )*n_codes + l];
464+
float phi = E[(k + n_embd/2)*n_codes + l];
465+
466+
mag = exp(mag);
467+
468+
if (mag > 1e2) {
469+
mag = 1e2;
470+
}
471+
S[2*(k*n_codes + l) + 0] = mag*cosf(phi);
472+
S[2*(k*n_codes + l) + 1] = mag*sinf(phi);
473+
}
474+
}
475+
476+
for (int l = 0; l < n_codes; ++l) {
477+
for (int k = 0; k < n_embd/2; ++k) {
478+
ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0];
479+
ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1];
480+
}
481+
}
482+
483+
std::vector<float> res (n_codes*n_fft);
484+
std::vector<float> hann2(n_codes*n_fft);
485+
486+
std::vector<std::thread> workers(n_thread);
487+
for (int i = 0; i < n_thread; ++i) {
488+
workers[i] = std::thread([&, i]() {
489+
for (int l = i; l < n_codes; l += n_thread) {
490+
irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft);
491+
for (int j = 0; j < n_fft; ++j) {
492+
res [l*n_fft + j] *= hann[j];
493+
hann2[l*n_fft + j] = hann[j] * hann[j];
494+
}
495+
}
496+
});
497+
}
498+
for (int i = 0; i < n_thread; ++i) {
499+
workers[i].join();
500+
}
501+
502+
std::vector<float> audio;
503+
std::vector<float> env;
504+
505+
fold(res, n_out, n_win, n_hop, n_pad, audio);
506+
fold(hann2, n_out, n_win, n_hop, n_pad, env); // TODO: can be done once
507+
508+
for (size_t i = 0; i < audio.size(); ++i) {
509+
audio[i] /= env[i];
510+
}
511+
512+
return audio;
513+
}
514+
325515
static void print_usage(int, char ** argv) {
326516
printf("\nexample usage:\n");
327517
printf("\n %s -m model.gguf -mv vocoder.gguf -v en_male_1.json -p \"Hello!\"\n", argv[0]);
@@ -476,5 +666,114 @@ int main(int argc, char ** argv) {
476666

477667
llama_synchronize(ctx);
478668

669+
// main loop
670+
671+
// remember the batch index of the last token for each parallel sequence
672+
// we need this to determine which logits to sample from
673+
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
674+
675+
int n_past = batch.n_tokens;
676+
int n_decode = 0;
677+
678+
bool next_token_uses_guide_token = true;
679+
479680
std::vector<llama_token> codes;
681+
682+
while (n_decode <= n_predict) {
683+
batch.n_tokens = 0;
684+
685+
// sample the next token for each parallel sequence / stream
686+
for (int32_t i = 0; i < n_parallel; ++i) {
687+
if (i_batch[i] < 0) {
688+
// the stream has already finished
689+
continue;
690+
}
691+
692+
llama_token new_token_id = llama_sampler_sample(&samplers[i], ctx, i_batch[i]);
693+
694+
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
695+
if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
696+
llama_token guide_token = guide_tokens[0];
697+
guide_tokens.erase(guide_tokens.begin());
698+
new_token_id = guide_token; //ensure correct word fragment is used
699+
}
700+
701+
//this is the token id that always precedes a new word
702+
next_token_uses_guide_token = (new_token_id == 198);
703+
704+
llama_sampler_accept(&samplers[i], new_token_id);
705+
706+
codes.push_back(new_token_id);
707+
708+
if (llama_vocab_is_eog(vocab, new_token_id) || n_decode == n_predict) {
709+
// Mark the stream as finished
710+
i_batch[i] = -1;
711+
continue;
712+
}
713+
714+
i_batch[i] = batch.n_tokens;
715+
716+
batch_add(batch, new_token_id, n_past, { i }, true);
717+
}
718+
719+
// all streams are finished
720+
if (batch.n_tokens == 0) {
721+
break;
722+
}
723+
724+
n_decode += 1;
725+
n_past += 1;
726+
727+
// evaluate the current batch with the transformer model
728+
if (llama_decode(ctx, batch)) {
729+
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
730+
return 1;
731+
}
732+
}
733+
734+
llama_batch_free(batch);
735+
736+
// remove all non-audio tokens (i.e. < 151672 || > 155772)
737+
codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end());
738+
739+
for (auto & token : codes) {
740+
token -= 151672;
741+
}
742+
743+
const int n_codes = codes.size();
744+
745+
llama_batch batch = llama_batch_init(n_codes, 0, 1);
746+
747+
for (size_t i = 0; i < codes.size(); ++i) {
748+
batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits?
749+
}
750+
751+
// evaluate the current batch with the transformer model
752+
if (llama_decode(ctx_vocoder, batch)) {
753+
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
754+
return 1;
755+
}
756+
757+
llama_synchronize(ctx_vocoder);
758+
759+
// spectral operations
760+
const int n_embd = llama_model_n_embd(vocoder);
761+
const float * embd = llama_get_embeddings(ctx_vocoder);
762+
763+
auto audio = embd_to_audio(embd, n_codes, n_embd, ctx_params.n_threads);
764+
765+
const std::string fname = "output.wav";
766+
767+
const int n_sr = 24000; // sampling rate
768+
769+
// zero out first 0.25 seconds
770+
for (int i = 0; i < 24000/4; ++i) {
771+
audio[i] = 0.0f;
772+
}
773+
774+
save_wav16(fname, audio, n_sr);
775+
776+
llama_backend_free();
777+
778+
return 0;
480779
}

0 commit comments

Comments
 (0)