@@ -128,6 +128,34 @@ static void twiddle(float * real, float * imag, int k, int N) {
128128 *imag = sin (angle);
129129}
130130
131+ struct theta {
132+ theta () : cosine(0 .0f ), sine(0 .0f ) {}
133+
134+ void reset (int k, int N) {
135+ twiddle (&cosine, &sine, k, N);
136+ }
137+
138+ const float & get_cos () const {return cosine;}
139+ const float & get_sin () const {return sine;}
140+
141+ private:
142+ float cosine;
143+ float sine;
144+ };
145+
146+ static std::vector<theta> thetas;
147+
148+ static void fill_thetas (int N) {
149+ thetas.resize (N);
150+ for (int k = 0 ; k < N; ++k) {
151+ thetas[k].reset (k, N);
152+ }
153+ }
154+
155+ static const theta& get_theta (int k) {
156+ return thetas[k % thetas.size ()];
157+ }
158+
131159static void irfft (int n, const float * inp_cplx, float * out_real) {
132160 int N = n / 2 + 1 ;
133161
@@ -160,6 +188,19 @@ static void irfft(int n, const float * inp_cplx, float * out_real) {
160188 }
161189}
162190
191+ static void irfft_2 (int n, const float * inp_cplx, float * out_real) {
192+ int N = n / 2 + 1 ;
193+
194+ for (int k = 0 ; k < n; ++k) {
195+ out_real[k] = 0 .0f ;
196+ for (int m = 0 ; m < N; ++m) {
197+ const theta & t = get_theta (k * m);
198+ out_real[k] += inp_cplx[2 * m] * t.get_cos () - inp_cplx[2 * m + 1 ] * t.get_sin ();
199+ }
200+ out_real[k] /= N;
201+ }
202+ }
203+
163204//
164205// y = torch.nn.functional.fold(
165206// data, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
@@ -207,6 +248,8 @@ static std::vector<float> embd_to_audio(
207248 const int n_pad = (n_win - n_hop)/2 ;
208249 const int n_out = (n_codes - 1 )*n_hop + n_win;
209250
251+ fill_thetas (n_fft);
252+
210253 std::vector<float > hann (n_fft);
211254
212255 fill_hann_window (hann.size (), true , hann.data ());
@@ -248,11 +291,13 @@ static std::vector<float> embd_to_audio(
248291 std::vector<float > res (n_codes*n_fft);
249292 std::vector<float > hann2 (n_codes*n_fft);
250293
294+ const auto t_irfft_start = ggml_time_us ();
295+
251296 std::vector<std::thread> workers (n_thread);
252297 for (int i = 0 ; i < n_thread; ++i) {
253298 workers[i] = std::thread ([&, i]() {
254299 for (int l = i; l < n_codes; l += n_thread) {
255- irfft (n_fft, ST.data () + l*n_embd, res.data () + l*n_fft);
300+ irfft_2 (n_fft, ST.data () + l*n_embd, res.data () + l*n_fft);
256301 for (int j = 0 ; j < n_fft; ++j) {
257302 res [l*n_fft + j] *= hann[j];
258303 hann2[l*n_fft + j] = hann[j] * hann[j];
@@ -264,6 +309,8 @@ static std::vector<float> embd_to_audio(
264309 workers[i].join ();
265310 }
266311
312+ LOG_INF (" %s: time irfft: %.3f ms\n " , __func__, (ggml_time_us () - t_irfft_start) / 1000 .0f );
313+
267314 std::vector<float > audio;
268315 std::vector<float > env;
269316
0 commit comments