diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp
index 4cc42e1674ccc..4cd5c7fc4c454 100644
--- a/examples/tts/tts.cpp
+++ b/examples/tts/tts.cpp
@@ -128,6 +128,50 @@ static void twiddle(float * real, float * imag, int k, int N) {
     *imag = sin(angle);
 }
 
+// cos/sin pair written by twiddle() for one (k, N) index
+struct theta {
+    theta() : cosine(0.0f), sine(0.0f) {}
+
+    // recompute the pair for index k of an N-entry table
+    void reset(int k, int N) {
+        twiddle(&cosine, &sine, k, N);
+    }
+
+    float get_cos() const { return cosine; }
+    float get_sin() const { return sine; }
+
+private:
+    float cosine;
+    float sine;
+};
+
+// lookup table used by irfft_2(); written only by fill_thetas()
+// NOTE: must not be refilled while worker threads are reading it
+static std::vector<theta> thetas;
+
+// make `thetas` hold the N values twiddle(k, N), k = 0..N-1
+static void fill_thetas(int N) {
+    if (N <= 0) {
+        thetas.clear();
+        return;
+    }
+    if (thetas.size() == (size_t) N) {
+        return; // the entries depend only on N, so the table is already valid
+    }
+    thetas.resize(N);
+    for (int k = 0; k < N; ++k) {
+        thetas[k].reset(k, N);
+    }
+}
+
+// entry for index k, wrapped into the table range (negative k wraps as well)
+// precondition: the table is not empty
+static const theta & get_theta(int k) {
+    const int n = (int) thetas.size();
+    const int i = k % n;
+    return thetas[i < 0 ? i + n : i];
+}
+
 static void irfft(int n, const float * inp_cplx, float * out_real) {
     int N = n / 2 + 1;
 
@@ -160,6 +204,33 @@ static void irfft(int n, const float * inp_cplx, float * out_real) {
     }
 }
 
+// intended as a drop-in replacement for irfft(): takes the cos/sin factors from the
+// `thetas` table instead of calling twiddle() for every (k, m) pair
+static void irfft_2(int n, const float * inp_cplx, float * out_real) {
+    if (n <= 0) {
+        return;
+    }
+
+    // the lookups below are only valid for a table filled with fill_thetas(n);
+    // otherwise use the direct implementation
+    if (thetas.size() != (size_t) n) {
+        irfft(n, inp_cplx, out_real);
+        return;
+    }
+
+    const int N = n / 2 + 1;
+
+    for (int k = 0; k < n; ++k) {
+        // accumulate in a local instead of read-modify-writing out_real[k] on every step
+        float sum = 0.0f;
+        for (int m = 0; m < N; ++m) {
+            const theta & t = get_theta(k * m);
+            sum += inp_cplx[2 * m] * t.get_cos() - inp_cplx[2 * m + 1] * t.get_sin();
+        }
+        out_real[k] = sum / N;
+    }
+}
+
 //
 // y = torch.nn.functional.fold(
 // data, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
@@ -207,6 +278,8 @@ static std::vector<float> embd_to_audio(
     const int n_pad = (n_win - n_hop)/2;
     const int n_out = (n_codes - 1)*n_hop + n_win;
 
+    fill_thetas(n_fft);
+
     std::vector<float> hann(n_fft);
 
     fill_hann_window(hann.size(), true, hann.data());
@@ -248,11 +321,13 @@ static std::vector<float> embd_to_audio(
     std::vector<float> res (n_codes*n_fft);
     std::vector<float> hann2(n_codes*n_fft);
 
+    const auto t_irfft_start = ggml_time_us();
+
     std::vector<std::thread> workers(n_thread);
     for (int i = 0; i < n_thread; ++i) {
         workers[i] = std::thread([&, i]() {
             for (int l = i; l < n_codes; l += n_thread) {
-                irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft);
+                irfft_2(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft);
                 for (int j = 0; j < n_fft; ++j) {
                     res [l*n_fft + j] *= hann[j];
                     hann2[l*n_fft + j] = hann[j] * hann[j];
@@ -264,6 +339,8 @@ static std::vector<float> embd_to_audio(
         workers[i].join();
     }
 
+    LOG_INF("%s: time irfft: %.3f ms\n", __func__, (ggml_time_us() - t_irfft_start) / 1000.0f);
+
     std::vector<float> audio;
     std::vector<float> env;
 