 #include "mimi-model.h"
 
 #include <vector>
+#include <regex>
 #include <fstream>
 #include <float.h>
 #include <cstring> // memcpy and strcmp
@@ -23,12 +24,39 @@ static void print_usage(int, char ** argv) {
     LOG("\nNote: the model need 2 files to run, one ends with '-backbone-<quant>.gguf' and the other ends with '-decoder<quant>.gguf'");
     LOG("\n");
     LOG("\nPrompt format:");
-    LOG("\nEach line must start with speaker ID in square brackets, followed by the text. A full stop is recommended at the end of each turn");
-    LOG("\nExample: [0]Hello world.");
+    LOG("\nEach line must start with speaker ID in square brackets, followed by the text. One turn per line. A full stop is recommended at the end of each turn");
+    LOG("\nExample:");
+    LOG("\n[0]Hey how are you doing.");
+    LOG("\n[1]Pretty good, pretty good.");
     LOG("\nIf you want to enter long text, use -f file.txt to read from file");
     LOG("\n");
 }
 
+// split text containing "[N]..." into speaker turns
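+// e.g. "[0]Hey how are you doing.\n[1]Pretty good, pretty good." is split into
+// "[0]Hey how are you doing.\n" and "[1]Pretty good, pretty good."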
+static std::vector<std::string> get_speaker_turns(const std::string & input) {
+    if (input.empty()) {
+        LOG_ERR("Empty input\n");
+        return {};
+    }
+    if (input[0] != '[') {
+        LOG_ERR("Invalid input format: missing speaker ID\n");
+        return {};
+    }
+    std::regex re(R"((\[\d+\][\s\S]*?)(?=\[\d+\]|$))");
+    std::smatch match;
+    std::vector<std::string> turns;
+    std::string::const_iterator searchStart(input.cbegin());
+    while (std::regex_search(searchStart, input.cend(), match, re)) {
+        std::string turn = match[1].str();
+        // advance past this match before any early continue, so we never re-scan the same position
+        searchStart = match.suffix().first;
+        if (turn.empty()) {
+            continue;
+        }
+        turns.push_back(turn);
+    }
+    return turns;
+}
+
 // sampling with custom n_vocab
 // modified version of llama_sampler_sample()
 static llama_token sample_token(struct llama_sampler * smpl, const float * logits, int n_vocab) {
@@ -81,9 +109,11 @@ int main(int argc, char ** argv) {
     params.sampling.top_k = 50; // default param from CSM python code
     params.sampling.temp  = 0.9; // default param from CSM python code
 
-    // HF model
-    params.model.url         = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf";
-    params.vocoder.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/kyutai-mimi.gguf";
+    // HF model (hack: we temporarily reuse speculative.model as the decoder model, only to get it downloaded)
+    params.model.url              = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf";
+    params.speculative.model.path = "sesame-csm-decoder.gguf";
+    params.speculative.model.url  = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-decoder.gguf";
+    params.vocoder.model.url      = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/kyutai-mimi.gguf";
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
         return 1;
@@ -125,32 +155,15 @@ int main(int argc, char ** argv) {
 
     mimi_model mimi(params.vocoder.model.path.c_str(), true);
 
-    // tokenize the prompt
-    const llama_vocab * vocab = llama_model_get_vocab(model_bb);
-    llama_tokens prompt_tokens = common_tokenize(vocab, params.prompt, false, true);
-    prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab));
-    prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab));
-
     // init sampler
     // the python implementation only has top-k and temperature sampling, so we'll use just that
     llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
     llama_sampler_chain_add(sampler.get(), llama_sampler_init_top_k(params.sampling.top_k));
     llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(params.sampling.temp));
     llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(params.sampling.seed));
 
-    printf("prompt tokens: \n");
-    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-        printf("%d, ", prompt_tokens[i]);
-    }
-    printf("\n");
-
-    llama_pos n_past_bb = 0;
     llama_batch batch_prompt = llama_batch_init(params.n_batch, 0, 1);
-    common_batch_clear(batch_prompt);
-    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-        common_batch_add(batch_prompt, prompt_tokens[i], n_past_bb++, { 0 }, false);
-    }
-    batch_prompt.logits[batch_prompt.n_tokens - 1] = true;
+    llama_pos n_past_bb = 0;
 
     // inp_past_embd is the "squashed" embeddings from the decoder
     std::vector<float> inp_past_embd(2048, 0.0f);
@@ -162,128 +175,154 @@ int main(int argc, char ** argv) {
     int64_t t_dc     = 0; // decoder time
     int64_t n_dc_gen = 0; // decoder generation count
 
-    bool is_stop = false;
     std::vector<int> generated_codes;
 
-    // backbone generation loop
-    for (int k = 0; k < params.n_predict; ++k) {
-        bool is_prompt_processing = k == 0;
-
-        if (!is_prompt_processing) {
-            // generate the next RVQ semantic token
-            batch_past_embd.n_tokens = 1;
-            batch_past_embd.pos[0] = n_past_bb++;
-            batch_past_embd.seq_id[0][0] = 0;
-            batch_past_embd.n_seq_id[0] = 1;
-            batch_past_embd.logits[0] = true;
-            std::memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float));
-        }
+    auto turns = get_speaker_turns(params.prompt);
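+    // note: n_past_bb is never reset between turns, so each turn is appended
+    // to the same backbone sequence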
+
+    for (const std::string & turn : turns) {
+        // tokenize the turn
+        llama_tokens prompt_tokens;
+        {
+            printf("\n---\nturn: %s\n\n", turn.c_str());
+            const llama_vocab * vocab = llama_model_get_vocab(model_bb);
+            prompt_tokens = common_tokenize(vocab, turn, false, true);
+            prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab));
+            prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab));
+
+            printf("prompt (%zu tokens): \n", prompt_tokens.size());
+            for (size_t i = 0; i < prompt_tokens.size(); ++i) {
+                printf("%d, ", prompt_tokens[i]);
+            }
+            printf("\n");
 
-        int64_t t_bb_start = ggml_time_ms();
-        if (llama_decode(ctx_bb, is_prompt_processing ? batch_prompt : batch_past_embd) != 0) {
-            LOG_ERR("%s: backbone llama_decode() failed\n", __func__);
-            return 1;
+            common_batch_clear(batch_prompt);
+            for (size_t i = 0; i < prompt_tokens.size(); ++i) {
+                common_batch_add(batch_prompt, prompt_tokens[i], n_past_bb++, { 0 }, false);
+            }
+            batch_prompt.logits[batch_prompt.n_tokens - 1] = true;
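+            // only the last prompt token needs logits: they seed the first semantic token of this turn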
         }
-        n_bb_gen++;
-        t_bb += ggml_time_ms() - t_bb_start;
 
-        auto vocab_dc = llama_model_get_vocab(model_dc);
-        auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? (batch_prompt.n_tokens - 1) : 0);
-        // for (size_t i = 0; i < 10; ++i) {
-        //     printf("%4.2f, ", logits[i]);
-        // }
-        // printf("\n");
+        // backbone generation loop
+        bool is_end_of_turn = false;
+        for (int k = 0; k < params.n_predict; ++k) {
+            bool is_prompt_processing = k == 0;
+
+            if (!is_prompt_processing) {
+                // generate the next RVQ semantic token
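+                // (the input here is not a token id but the squashed decoder embedding of the previous frame)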
+                batch_past_embd.n_tokens = 1;
+                batch_past_embd.pos[0] = n_past_bb++;
+                batch_past_embd.seq_id[0][0] = 0;
+                batch_past_embd.n_seq_id[0] = 1;
+                batch_past_embd.logits[0] = true;
+                std::memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float));
+            }
 
-        llama_token semantic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
-        printf("Sem token %5d : %d,", 1 + (int)generated_codes.size()/32, semantic_tok);
-        generated_codes.push_back(semantic_tok);
+            int64_t t_bb_start = ggml_time_ms();
+            if (llama_decode(ctx_bb, is_prompt_processing ? batch_prompt : batch_past_embd) != 0) {
+                LOG_ERR("%s: backbone llama_decode() failed\n", __func__);
+                return 1;
+            }
+            n_bb_gen++;
+            t_bb += ggml_time_ms() - t_bb_start;
 
-        // for (size_t i = 0; i < 10; ++i) {
-        //     printf("%4.2f, ", embd[i]);
-        // }
-        // printf("\n");
+            auto vocab_dc = llama_model_get_vocab(model_dc);
+            auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? (batch_prompt.n_tokens - 1) : 0);
+            // for (size_t i = 0; i < 10; ++i) {
+            //     printf("%4.2f, ", logits[i]);
+            // }
+            // printf("\n");
 
+            llama_token semantic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
+            printf("Sem token %5d : %d,", 1 + (int)generated_codes.size()/32, semantic_tok);
+            generated_codes.push_back(semantic_tok);
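+            // one audio frame = 32 codes: this semantic token plus the 31 acoustic tokens generated below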
 
-        // decoder generation loop
-        inp_past_embd = std::vector<float>(inp_past_embd.size(), 0.0f);
-        {
-            llama_kv_self_clear(ctx_dc);
-            llama_batch batch_embd  = llama_batch_init(1, embd.size(), 1);
-            llama_batch batch_token = llama_batch_init(1, 0, 1);
+            // for (size_t i = 0; i < 10; ++i) {
+            //     printf("%4.2f, ", embd[i]);
+            // }
+            // printf("\n");
 
-            // first "token" is the latent embeddings from backbone
-            {
-                batch_embd.n_tokens = 1;
-                batch_embd.pos[0] = 0;
-                batch_embd.seq_id[0][0] = 0;
-                batch_embd.n_seq_id[0] = 1;
-                batch_embd.logits[0] = false;
-                std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
-            }
-            if (llama_decode(ctx_dc, batch_embd) != 0) {
-                LOG_ERR("%s: decoder llama_decode(embd) failed\n", __func__);
-                return 1;
-            }
 
-            // then, decode the semantic_tok to generate acoustic tokens
-            llama_token tok = semantic_tok;
-            int n_codes = 32;
-            int sum_codes = semantic_tok; // to check if all codes are 0
-            for (int i = 0; i < n_codes; ++i) {
-                common_batch_clear(batch_token);
-                // encoder vocab is further divided into 32 codebooks, each with 2051 entries
-                llama_token inp_tok = tok + 2051*i;
-                common_batch_add(batch_token, inp_tok, i+1, { 0 }, true);
-
-                int64_t t_bb_start = ggml_time_ms();
-                if (llama_decode(ctx_dc, batch_token) != 0) {
-                    LOG_ERR("%s: decoder llama_decode(token) failed\n", __func__);
-                    return 1;
+            // decoder generation loop
+            inp_past_embd = std::vector<float>(inp_past_embd.size(), 0.0f);
+            {
+                llama_kv_self_clear(ctx_dc);
+                llama_batch batch_embd  = llama_batch_init(1, embd.size(), 1);
+                llama_batch batch_token = llama_batch_init(1, 0, 1);
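+                // the decoder works on a tiny per-frame sequence (1 embedding "token" + 32 code tokens)
+                // and its KV cache is cleared before every frame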
+
+                // first "token" is the latent embeddings from backbone
+                {
+                    batch_embd.n_tokens = 1;
+                    batch_embd.pos[0] = 0;
+                    batch_embd.seq_id[0][0] = 0;
+                    batch_embd.n_seq_id[0] = 1;
+                    batch_embd.logits[0] = false;
+                    std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
                 }
-                n_dc_gen++;
-                t_dc += ggml_time_ms() - t_bb_start;
-
-                // sample the acoustic token
-                auto logits = llama_get_logits_ith(ctx_dc, 0);
-                llama_token acoustic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
-
-                // discard last code (only for embeddings)
-                if (i < n_codes - 1) {
-                    printf("%d,", acoustic_tok);
-                    tok = acoustic_tok; // next input token
-                    sum_codes += acoustic_tok;
-                    generated_codes.push_back(acoustic_tok);
+                if (llama_decode(ctx_dc, batch_embd) != 0) {
+                    LOG_ERR("%s: decoder llama_decode(embd) failed\n", __func__);
+                    return 1;
                 }
 
-                // do progressive hsum of embeddings
-                GGML_ASSERT(inp_past_embd.size() == embd.size());
-                for (size_t i = 0; i < inp_past_embd.size(); ++i) {
-                    inp_past_embd[i] += embd[i];
+                // then, decode the semantic_tok to generate acoustic tokens
+                llama_token tok = semantic_tok;
+                int n_codes = 32;
+                int sum_codes = semantic_tok; // to check if all codes are 0
+                for (int i = 0; i < n_codes; ++i) {
+                    common_batch_clear(batch_token);
+                    // encoder vocab is further divided into 32 codebooks, each with 2051 entries
+                    llama_token inp_tok = tok + 2051*i;
+                    common_batch_add(batch_token, inp_tok, i+1, { 0 }, true);
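+                    // e.g. at step i=2 with tok=7, the input id is 7 + 2051*2 = 4109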
+
+                    int64_t t_bb_start = ggml_time_ms();
+                    if (llama_decode(ctx_dc, batch_token) != 0) {
+                        LOG_ERR("%s: decoder llama_decode(token) failed\n", __func__);
+                        return 1;
+                    }
+                    n_dc_gen++;
+                    t_dc += ggml_time_ms() - t_bb_start;
+
+                    // sample the acoustic token
+                    auto logits = llama_get_logits_ith(ctx_dc, 0);
+                    llama_token acoustic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
+
+                    // discard last code (only for embeddings)
+                    if (i < n_codes - 1) {
+                        printf("%d,", acoustic_tok);
+                        tok = acoustic_tok; // next input token
+                        sum_codes += acoustic_tok;
+                        generated_codes.push_back(acoustic_tok);
+                    }
+
+                    // do progressive hsum of embeddings
+                    GGML_ASSERT(inp_past_embd.size() == embd.size());
+                    for (size_t i = 0; i < inp_past_embd.size(); ++i) {
+                        inp_past_embd[i] += embd[i];
+                    }
                 }
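+                // inp_past_embd now holds the accumulated ("squashed") embeddings of this frame;
+                // it becomes the backbone input on the next iteration of the outer loop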
-            }
-            printf("\n");
+                printf("\n");
 
-            llama_batch_free(batch_embd);
-            llama_batch_free(batch_token);
+                llama_batch_free(batch_embd);
+                llama_batch_free(batch_token);
 
-            // if all codes are 0, then we are done
-            is_stop = sum_codes == 0;
-        }
+                // if all codes are 0, then we are done
+                is_end_of_turn = sum_codes == 0;
+            }
 
-        // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb);
-        // for (size_t i = 0; i < inp_past_embd.size(); ++i) {
-        //     printf("%4.4f, ", inp_past_embd[i]);
-        //     if (i == 2) {
-        //         printf("... ");
-        //         i = inp_past_embd.size() - 4;
-        //     }
-        // }
-        // printf("\n");
-
-        if (is_stop) {
-            // remove last 32 codes since they will be all zeros
-            generated_codes.resize(generated_codes.size() - 32);
-            break;
+            // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb);
+            // for (size_t i = 0; i < inp_past_embd.size(); ++i) {
+            //     printf("%4.4f, ", inp_past_embd[i]);
+            //     if (i == 2) {
+            //         printf("... ");
+            //         i = inp_past_embd.size() - 4;
+            //     }
+            // }
+            // printf("\n");
+
+            if (is_end_of_turn) {
+                // remove last 32 codes since they will be all zeros
+                generated_codes.resize(generated_codes.size() - 32);
+                break;
+            }
         }
     }
 