@@ -63,7 +63,47 @@ static void print_usage(int, char ** argv) {
6363 LOG (" \n " );
6464}
6565
66- static void fill_hann_window (int length, bool periodic, double * output) {
66+ struct wav_header {
67+ char riff[4 ] = {' R' , ' I' , ' F' , ' F' };
68+ uint32_t chunk_size;
69+ char wave[4 ] = {' W' , ' A' , ' V' , ' E' };
70+ char fmt[4 ] = {' f' , ' m' , ' t' , ' ' };
71+ uint32_t fmt_chunk_size = 16 ;
72+ uint16_t audio_format = 1 ; // PCM
73+ uint16_t num_channels = 1 ; // Mono
74+ uint32_t sample_rate;
75+ uint32_t byte_rate;
76+ uint16_t block_align;
77+ uint16_t bits_per_sample = 16 ;
78+ char data[4 ] = {' d' , ' a' , ' t' , ' a' };
79+ uint32_t data_size;
80+ };
81+
82+ static void save_wav16 (const std::string & fname, const std::vector<float > & data, int sample_rate) {
83+ std::ofstream file (fname, std::ios::binary);
84+ if (!file) {
85+ LOG_ERR (" %s: Failed to open file '%s' for writing" , __func__, fname.c_str ());
86+ return ;
87+ }
88+
89+ wav_header header;
90+ header.sample_rate = sample_rate;
91+ header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8 );
92+ header.block_align = header.num_channels * (header.bits_per_sample / 8 );
93+ header.data_size = data.size () * (header.bits_per_sample / 8 );
94+ header.chunk_size = 36 + header.data_size ;
95+
96+ file.write (reinterpret_cast <const char *>(&header), sizeof (header));
97+
98+ for (const auto & sample : data) {
99+ int16_t pcm_sample = static_cast <int16_t >(std::clamp (sample * 32767.0 , -32768.0 , 32767.0 ));
100+ file.write (reinterpret_cast <const char *>(&pcm_sample), sizeof (pcm_sample));
101+ }
102+
103+ file.close ();
104+ }
105+
106+ static void fill_hann_window (int length, bool periodic, float * output) {
67107 int offset = -1 ;
68108 if (periodic) {
69109 offset = 0 ;
@@ -74,31 +114,31 @@ static void fill_hann_window(int length, bool periodic, double * output) {
74114}
75115
76116// very poor-man fft
77- static void twiddle (double * real, double * imag, int k, int N) {
78- double angle = 2 * M_PI * k / N;
117+ static void twiddle (float * real, float * imag, int k, int N) {
118+ float angle = 2 * M_PI * k / N;
79119 *real = cos (angle);
80120 *imag = sin (angle);
81121}
82122
83- static void irfft (int n, const double * inp_cplx, double * out_real) {
123+ static void irfft (int n, const float * inp_cplx, float * out_real) {
84124 int N = n / 2 + 1 ;
85125
86- std::vector<double > real_input (N);
87- std::vector<double > imag_input (N);
126+ std::vector<float > real_input (N);
127+ std::vector<float > imag_input (N);
88128 for (int i = 0 ; i < N; ++i) {
89129 real_input[i] = inp_cplx[2 * i];
90130 imag_input[i] = inp_cplx[2 * i + 1 ];
91131 }
92132
93- std::vector<double > real_output (n);
94- std::vector<double > imag_output (n);
133+ std::vector<float > real_output (n);
134+ std::vector<float > imag_output (n);
95135
96136 for (int k = 0 ; k < n; ++k) {
97137 real_output[k] = 0 .0f ;
98138 imag_output[k] = 0 .0f ;
99139 for (int m = 0 ; m < N; ++m) {
100- double twiddle_real;
101- double twiddle_imag;
140+ float twiddle_real;
141+ float twiddle_imag;
102142
103143 twiddle (&twiddle_real, &twiddle_imag, k * m, n);
104144
@@ -123,7 +163,7 @@ static void irfft(int n, const double * inp_cplx, double * out_real) {
123163// hop_length = 320
124164// pad = 480
125165//
126- static void fold (const std::vector<double > & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<double > & output) {
166+ static void fold (const std::vector<float > & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float > & output) {
127167 int64_t output_height = n_out;
128168 int64_t kernel_w = n_win;
129169 int64_t stride_w = n_hop;
@@ -147,103 +187,63 @@ static void fold(const std::vector<double> & data, int64_t n_out, int64_t n_win,
147187 output.resize (n_out - 2 * n_pad);
148188}
149189
150- struct wav_header {
151- char riff[4 ] = {' R' , ' I' , ' F' , ' F' };
152- uint32_t chunk_size;
153- char wave[4 ] = {' W' , ' A' , ' V' , ' E' };
154- char fmt[4 ] = {' f' , ' m' , ' t' , ' ' };
155- uint32_t fmt_chunk_size = 16 ;
156- uint16_t audio_format = 1 ; // PCM
157- uint16_t num_channels = 1 ; // Mono
158- uint32_t sample_rate;
159- uint32_t byte_rate;
160- uint16_t block_align;
161- uint16_t bits_per_sample = 16 ;
162- char data[4 ] = {' d' , ' a' , ' t' , ' a' };
163- uint32_t data_size;
164- };
165-
166- static void save_wav16 (const std::string & fname, const std::vector<double > & data, int sample_rate) {
167- std::ofstream file (fname, std::ios::binary);
168- if (!file) {
169- LOG_ERR (" %s: Failed to open file '%s' for writing" , __func__, fname.c_str ());
170- return ;
171- }
172-
173- wav_header header;
174- header.sample_rate = sample_rate;
175- header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8 );
176- header.block_align = header.num_channels * (header.bits_per_sample / 8 );
177- header.data_size = data.size () * (header.bits_per_sample / 8 );
178- header.chunk_size = 36 + header.data_size ;
179-
180- file.write (reinterpret_cast <const char *>(&header), sizeof (header));
181-
182- for (const auto & sample : data) {
183- int16_t pcm_sample = static_cast <int16_t >(std::clamp (sample * 32767.0 , -32768.0 , 32767.0 ));
184- file.write (reinterpret_cast <const char *>(&pcm_sample), sizeof (pcm_sample));
185- }
186-
187- file.close ();
188- }
189-
190- static std::vector<double > embd_to_audio (
190+ // TODO: not optimized at all
191+ static std::vector<float > embd_to_audio (
191192 const float * embd,
192- const std::vector<llama_token> & codes ,
193+ const int n_codes ,
193194 const int n_embd,
194195 const int n_thread) {
195- const int n = codes.size ();
196196 const int n_fft = 1280 ;
197197 const int n_hop = 320 ;
198198 const int n_win = 1280 ;
199199 const int n_pad = (n_win - n_hop)/2 ;
200- const int n_out = (n - 1 )*n_hop + n_win;
200+ const int n_out = (n_codes - 1 )*n_hop + n_win;
201201
202- std::vector<double > hann (n_fft);
202+ std::vector<float > hann (n_fft);
203203
204204 fill_hann_window (hann.size (), true , hann.data ());
205205
206- int n_spec = n_embd*n ;
206+ int n_spec = n_embd*n_codes ;
207207
208- std::vector<double > E (n_spec);
209- std::vector<double > S (n_spec);
210- std::vector<double > ST (n_spec);
208+ std::vector<float > E (n_spec);
209+ std::vector<float > S (n_spec);
210+ std::vector<float > ST (n_spec);
211211
212- for (int l = 0 ; l < n ; ++l) {
212+ for (int l = 0 ; l < n_codes ; ++l) {
213213 for (int k = 0 ; k < n_embd; ++k) {
214- E[k*n + l] = embd[l*n_embd + k];
214+ E[k*n_codes + l] = embd[l*n_embd + k];
215215 }
216216 }
217217
218218 for (int k = 0 ; k < n_embd/2 ; ++k) {
219- for (int l = 0 ; l < n ; ++l) {
220- double mag = E[(k )*n + l];
221- double phi = E[(k + n_embd/2 )*n + l];
219+ for (int l = 0 ; l < n_codes ; ++l) {
220+ float mag = E[(k )*n_codes + l];
221+ float phi = E[(k + n_embd/2 )*n_codes + l];
222222
223223 mag = exp (mag);
224224
225225 if (mag > 1e2 ) {
226226 mag = 1e2 ;
227227 }
228- S[2 *(k*n + l) + 0 ] = mag*cosf (phi);
229- S[2 *(k*n + l) + 1 ] = mag*sinf (phi);
228+ S[2 *(k*n_codes + l) + 0 ] = mag*cosf (phi);
229+ S[2 *(k*n_codes + l) + 1 ] = mag*sinf (phi);
230230 }
231231 }
232232
233- for (int l = 0 ; l < n ; ++l) {
233+ for (int l = 0 ; l < n_codes ; ++l) {
234234 for (int k = 0 ; k < n_embd/2 ; ++k) {
235- ST[l*n_embd + 2 *k + 0 ] = S[2 *(k*n + l) + 0 ];
236- ST[l*n_embd + 2 *k + 1 ] = S[2 *(k*n + l) + 1 ];
235+ ST[l*n_embd + 2 *k + 0 ] = S[2 *(k*n_codes + l) + 0 ];
236+ ST[l*n_embd + 2 *k + 1 ] = S[2 *(k*n_codes + l) + 1 ];
237237 }
238238 }
239239
240- std::vector<double > res (n *n_fft);
241- std::vector<double > hann2 (n *n_fft);
240+ std::vector<float > res (n_codes *n_fft);
241+ std::vector<float > hann2 (n_codes *n_fft);
242242
243243 std::vector<std::thread> workers (n_thread);
244244 for (int i = 0 ; i < n_thread; ++i) {
245245 workers[i] = std::thread ([&, i]() {
246- for (int l = i; l < n ; l += n_thread) {
246+ for (int l = i; l < n_codes ; l += n_thread) {
247247 irfft (n_fft, ST.data () + l*n_embd, res.data () + l*n_fft);
248248 for (int j = 0 ; j < n_fft; ++j) {
249249 res [l*n_fft + j] *= hann[j];
@@ -256,8 +256,8 @@ static std::vector<double> embd_to_audio(
256256 workers[i].join ();
257257 }
258258
259- std::vector<double > audio;
260- std::vector<double > env;
259+ std::vector<float > audio;
260+ std::vector<float > env;
261261
262262 fold (res, n_out, n_win, n_hop, n_pad, audio);
263263 fold (hann2, n_out, n_win, n_hop, n_pad, env); // TODO: can be done once
@@ -844,12 +844,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
844844
845845 const auto t_voc_start = ggml_time_us ();
846846
847- llama_batch batch = llama_batch_init (codes.size (), 0 , 1 );
847+ const int n_codes = codes.size ();
848+
849+ llama_batch batch = llama_batch_init (n_codes, 0 , 1 );
848850
849851 for (size_t i = 0 ; i < codes.size (); ++i) {
850852 common_batch_add (batch, codes[i], i, { 0 }, true ); // TODO: all logits?
851853 }
852- GGML_ASSERT (batch.n_tokens == ( int ) codes. size () );
854+ GGML_ASSERT (batch.n_tokens == n_codes );
853855
854856 if (llama_decode (ctx_cts, batch) != 0 ) {
855857 LOG_ERR (" %s: llama_decode() failed\n " , __func__);
@@ -862,12 +864,40 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
862864
863865 const auto t_spec_start = ggml_time_us ();
864866
867+ #if 1
865868 // spectral operations
866- // TODO: not optimized at all
867869 const int n_embd = llama_n_embd (model_cts);
868870 const float * embd = llama_get_embeddings (ctx_cts);
869871
870- auto audio = embd_to_audio (embd, codes, n_embd, params.cpuparams .n_threads );
872+ auto audio = embd_to_audio (embd, n_codes, n_embd, params.cpuparams .n_threads );
873+
874+ #else
875+ // read the spectrogram from a file for debugging purposes
876+ std::vector<float> audio;
877+ {
878+ std::ifstream fin("out.bin", std::ios::binary);
879+ if (!fin) {
880+ LOG_ERR("%s: failed to open file '%s'\n", __func__, "out.bin");
881+ return 1;
882+ }
883+
884+ std::vector<float> embd;
885+
886+ int n_codes;
887+ int n_embd;
888+
889+ fin.read(reinterpret_cast<char *>(&n_codes), sizeof(int));
890+ fin.read(reinterpret_cast<char *>(&n_embd), sizeof(int));
891+
892+ embd.resize(n_codes * n_embd);
893+ fin.read(reinterpret_cast<char *>(embd.data()), n_codes * n_embd * sizeof(float));
894+ fin.close();
895+
896+ LOG_INF("%s: n_codes: %d, n_embd: %d\n", __func__, n_codes, n_embd);
897+
898+ audio = embd_to_audio(embd.data(), n_codes, n_embd, params.cpuparams.n_threads);
899+ }
900+ #endif
871901
872902 const std::string fname = " output.wav" ;
873903
0 commit comments