@@ -322,6 +322,196 @@ void batch_add(struct llama_batch & batch, llama_token id,llama_pos pos, const s
322322 batch.n_tokens ++;
323323}
324324
325+ static void save_wav16 (const std::string & fname, const std::vector<float > & data, int sample_rate) {
326+ std::ofstream file (fname, std::ios::binary);
327+ if (!file) {
328+ fprintf (stderr, " %s: Failed to open file '%s' for writing" , __func__, fname.c_str ());
329+ return ;
330+ }
331+
332+ wav_header header;
333+ header.sample_rate = sample_rate;
334+ header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8 );
335+ header.block_align = header.num_channels * (header.bits_per_sample / 8 );
336+ header.data_size = data.size () * (header.bits_per_sample / 8 );
337+ header.chunk_size = 36 + header.data_size ;
338+
339+ file.write (reinterpret_cast <const char *>(&header), sizeof (header));
340+
341+ for (const auto & sample : data) {
342+ int16_t pcm_sample = static_cast <int16_t >(std::clamp (sample * 32767.0 , -32768.0 , 32767.0 ));
343+ file.write (reinterpret_cast <const char *>(&pcm_sample), sizeof (pcm_sample));
344+ }
345+
346+ file.close ();
347+ }
348+
// Fill output[0 .. length-1] with a Hann window: 0.5 * (1 - cos(2*pi*i / period)).
//
//   length   - number of samples to write; nothing is written when length <= 0
//   periodic - true: period == length; false: period == length - 1 (symmetric window)
//   output   - destination buffer with room for `length` floats
static void fill_hann_window(int length, bool periodic, float * output) {
    const int period = periodic ? length : length - 1;

    if (period == 0) {
        // reached with writable samples only for length == 1 in the symmetric case:
        // the formula would evaluate 0/0 and store NaN, so emit the conventional
        // single-sample window value 1 instead
        for (int i = 0; i < length; i++) {
            output[i] = 1.0f;
        }
        return;
    }

    for (int i = 0; i < length; i++) {
        output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / period));
    }
}
358+
359+ // very poor-man fft
// Compute the complex twiddle factor exp(i * 2*pi * k / N).
//
//   real, imag - receive cos and sin of the angle
//   k          - index (may be much larger than N)
//   N          - transform length
static void twiddle(float * real, float * imag, int k, int N) {
    // the factor is periodic in k with period N: reducing k first keeps the angle
    // within one turn, so large k (e.g. products of two indices) no longer lose
    // precision; N == 0 is left to the floating-point division as before
    const int k_red = (N != 0) ? k % N : k;

    // evaluate the angle in double and round only the final results to float
    const double angle = 2.0 * M_PI * k_red / N;
    *real = static_cast<float>(cos(angle));
    *imag = static_cast<float>(sin(angle));
}
365+
366+ static void irfft (int n, const float * inp_cplx, float * out_real) {
367+ int N = n / 2 + 1 ;
368+
369+ std::vector<float > real_input (N);
370+ std::vector<float > imag_input (N);
371+ for (int i = 0 ; i < N; ++i) {
372+ real_input[i] = inp_cplx[2 * i];
373+ imag_input[i] = inp_cplx[2 * i + 1 ];
374+ }
375+
376+ std::vector<float > real_output (n);
377+ std::vector<float > imag_output (n);
378+
379+ for (int k = 0 ; k < n; ++k) {
380+ real_output[k] = 0 .0f ;
381+ imag_output[k] = 0 .0f ;
382+ for (int m = 0 ; m < N; ++m) {
383+ float twiddle_real;
384+ float twiddle_imag;
385+
386+ twiddle (&twiddle_real, &twiddle_imag, k * m, n);
387+
388+ real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag;
389+ imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real;
390+ }
391+ }
392+
393+ for (int i = 0 ; i < n; ++i) {
394+ out_real[i] = real_output[i] / N;
395+ }
396+ }
397+
398+ //
399+ // y = torch.nn.functional.fold(
400+ // data, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
401+ // )[:, 0, 0, pad:-pad]
402+ //
403+ // data.shape = torch.Size([1, 1280, 261])
404+ // output_size = 84480
405+ // win_length = 1280
406+ // hop_length = 320
407+ // pad = 480
408+ //
// Overlap-add consecutive frames of `data` into `output` and trim both ends.
//
//   data   - frames stored back to back; frame f occupies data[f*n_win .. f*n_win + n_win - 1]
//   n_out  - length of the untrimmed signal
//   n_win  - samples per frame
//   n_hop  - distance between the starts of consecutive frames
//   n_pad  - samples dropped from each end of the untrimmed signal
//   output - overwritten with the trimmed result (n_out - 2*n_pad samples, never fewer than 0)
//
// Frame f is added at offset f*n_hop - n_pad, i.e. the leading padding is skipped
// while accumulating; the trailing padding is removed by the final resize.
static void fold(const std::vector<float> & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float> & output) {
    // assign (not resize) so that previous contents of `output` never leak into the sum
    output.assign(n_out > 0 ? n_out : 0, 0.0f);

    const int64_t n_data = (int64_t) data.size();

    // stop as soon as the input is exhausted: later frames cannot contribute anything
    int64_t src = 0;
    for (int64_t frame = 0; frame < n_out && src < n_data; ++frame) {
        const int64_t dst_begin = frame * n_hop - n_pad;

        for (int64_t j = 0; j < n_win && src < n_data; ++j, ++src) {
            const int64_t dst = dst_begin + j;
            if (dst >= 0 && dst < n_out) {
                output[dst] += data[src];
            }
        }
    }

    // a negative length would wrap around to a huge size_t
    const int64_t n_keep = n_out - 2 * n_pad;
    output.resize(n_keep > 0 ? n_keep : 0);
}
432+
433+ // TODO: not optimized at all
434+ static std::vector<float > embd_to_audio (
435+ const float * embd,
436+ const int n_codes,
437+ const int n_embd,
438+ const int n_thread) {
439+ const int n_fft = 1280 ;
440+ const int n_hop = 320 ;
441+ const int n_win = 1280 ;
442+ const int n_pad = (n_win - n_hop)/2 ;
443+ const int n_out = (n_codes - 1 )*n_hop + n_win;
444+
445+ std::vector<float > hann (n_fft);
446+
447+ fill_hann_window (hann.size (), true , hann.data ());
448+
449+ int n_spec = n_embd*n_codes;
450+
451+ std::vector<float > E (n_spec);
452+ std::vector<float > S (n_spec);
453+ std::vector<float > ST (n_spec);
454+
455+ for (int l = 0 ; l < n_codes; ++l) {
456+ for (int k = 0 ; k < n_embd; ++k) {
457+ E[k*n_codes + l] = embd[l*n_embd + k];
458+ }
459+ }
460+
461+ for (int k = 0 ; k < n_embd/2 ; ++k) {
462+ for (int l = 0 ; l < n_codes; ++l) {
463+ float mag = E[(k )*n_codes + l];
464+ float phi = E[(k + n_embd/2 )*n_codes + l];
465+
466+ mag = exp (mag);
467+
468+ if (mag > 1e2 ) {
469+ mag = 1e2 ;
470+ }
471+ S[2 *(k*n_codes + l) + 0 ] = mag*cosf (phi);
472+ S[2 *(k*n_codes + l) + 1 ] = mag*sinf (phi);
473+ }
474+ }
475+
476+ for (int l = 0 ; l < n_codes; ++l) {
477+ for (int k = 0 ; k < n_embd/2 ; ++k) {
478+ ST[l*n_embd + 2 *k + 0 ] = S[2 *(k*n_codes + l) + 0 ];
479+ ST[l*n_embd + 2 *k + 1 ] = S[2 *(k*n_codes + l) + 1 ];
480+ }
481+ }
482+
483+ std::vector<float > res (n_codes*n_fft);
484+ std::vector<float > hann2 (n_codes*n_fft);
485+
486+ std::vector<std::thread> workers (n_thread);
487+ for (int i = 0 ; i < n_thread; ++i) {
488+ workers[i] = std::thread ([&, i]() {
489+ for (int l = i; l < n_codes; l += n_thread) {
490+ irfft (n_fft, ST.data () + l*n_embd, res.data () + l*n_fft);
491+ for (int j = 0 ; j < n_fft; ++j) {
492+ res [l*n_fft + j] *= hann[j];
493+ hann2[l*n_fft + j] = hann[j] * hann[j];
494+ }
495+ }
496+ });
497+ }
498+ for (int i = 0 ; i < n_thread; ++i) {
499+ workers[i].join ();
500+ }
501+
502+ std::vector<float > audio;
503+ std::vector<float > env;
504+
505+ fold (res, n_out, n_win, n_hop, n_pad, audio);
506+ fold (hann2, n_out, n_win, n_hop, n_pad, env); // TODO: can be done once
507+
508+ for (size_t i = 0 ; i < audio.size (); ++i) {
509+ audio[i] /= env[i];
510+ }
511+
512+ return audio;
513+ }
514+
325515static void print_usage (int , char ** argv) {
326516 printf (" \n example usage:\n " );
327517 printf (" \n %s -m model.gguf -mv vocoder.gguf -v en_male_1.json -p \" Hello!\"\n " , argv[0 ]);
@@ -476,5 +666,114 @@ int main(int argc, char ** argv) {
476666
477667 llama_synchronize (ctx);
478668
669+ // main loop
670+
671+ // remember the batch index of the last token for each parallel sequence
672+ // we need this to determine which logits to sample from
673+ std::vector<int32_t > i_batch (n_parallel, batch.n_tokens - 1 );
674+
675+ int n_past = batch.n_tokens ;
676+ int n_decode = 0 ;
677+
678+ bool next_token_uses_guide_token = true ;
679+
479680 std::vector<llama_token> codes;
681+
682+ while (n_decode <= n_predict) {
683+ batch.n_tokens = 0 ;
684+
685+ // sample the next token for each parallel sequence / stream
686+ for (int32_t i = 0 ; i < n_parallel; ++i) {
687+ if (i_batch[i] < 0 ) {
688+ // the stream has already finished
689+ continue ;
690+ }
691+
692+ llama_token new_token_id = llama_sampler_sample (&samplers[i], ctx, i_batch[i]);
693+
694+ // guide tokens help prevent hallucinations by forcing the TTS to use the correct word
695+ if (!guide_tokens.empty () && next_token_uses_guide_token && !llama_vocab_is_control (vocab, new_token_id) && !llama_vocab_is_eog (vocab, new_token_id)) {
696+ llama_token guide_token = guide_tokens[0 ];
697+ guide_tokens.erase (guide_tokens.begin ());
698+ new_token_id = guide_token; // ensure correct word fragment is used
699+ }
700+
701+ // this is the token id that always precedes a new word
702+ next_token_uses_guide_token = (new_token_id == 198 );
703+
704+ llama_sampler_accept (&samplers[i], new_token_id);
705+
706+ codes.push_back (new_token_id);
707+
708+ if (llama_vocab_is_eog (vocab, new_token_id) || n_decode == n_predict) {
709+ // Mark the stream as finished
710+ i_batch[i] = -1 ;
711+ continue ;
712+ }
713+
714+ i_batch[i] = batch.n_tokens ;
715+
716+ batch_add (batch, new_token_id, n_past, { i }, true );
717+ }
718+
719+ // all streams are finished
720+ if (batch.n_tokens == 0 ) {
721+ break ;
722+ }
723+
724+ n_decode += 1 ;
725+ n_past += 1 ;
726+
727+ // evaluate the current batch with the transformer model
728+ if (llama_decode (ctx, batch)) {
729+ fprintf (stderr, " %s: llama_decode() failed\n " , __func__);
730+ return 1 ;
731+ }
732+ }
733+
734+ llama_batch_free (batch);
735+
736+ // remove all non-audio tokens (i.e. < 151672 || > 155772)
737+ codes.erase (std::remove_if (codes.begin (), codes.end (), [](llama_token t) { return t < 151672 || t > 155772 ; }), codes.end ());
738+
739+ for (auto & token : codes) {
740+ token -= 151672 ;
741+ }
742+
743+ const int n_codes = codes.size ();
744+
745+ llama_batch batch = llama_batch_init (n_codes, 0 , 1 );
746+
747+ for (size_t i = 0 ; i < codes.size (); ++i) {
748+ batch_add (batch, codes[i], i, { 0 }, true ); // TODO: all logits?
749+ }
750+
751+ // evaluate the current batch with the transformer model
752+ if (llama_decode (ctx_vocoder, batch)) {
753+ fprintf (stderr, " %s: llama_decode() failed\n " , __func__);
754+ return 1 ;
755+ }
756+
757+ llama_synchronize (ctx_vocoder);
758+
759+ // spectral operations
760+ const int n_embd = llama_model_n_embd (vocoder);
761+ const float * embd = llama_get_embeddings (ctx_vocoder);
762+
763+ auto audio = embd_to_audio (embd, n_codes, n_embd, ctx_params.n_threads );
764+
765+ const std::string fname = " output.wav" ;
766+
767+ const int n_sr = 24000 ; // sampling rate
768+
769+ // zero out first 0.25 seconds
770+ for (int i = 0 ; i < 24000 /4 ; ++i) {
771+ audio[i] = 0 .0f ;
772+ }
773+
774+ save_wav16 (fname, audio, n_sr);
775+
776+ llama_backend_free ();
777+
778+ return 0 ;
480779}
0 commit comments