22#include " common.h"
33#include " log.h"
44#include " arg.h"
5+ #include " mimi-model.h"
56
67#include < vector>
78#include < fstream>
1314
1415static void print_usage (int , char ** argv) {
1516 LOG (" \n example usage:\n " );
16- LOG (" \n %s TODO " , argv[0 ]);
17+ LOG (" \n By default, model will be downloaded from https://huggingface.co/ggml-org/sesame-csm-1b-GGUF" );
18+ LOG (" \n %s -p \" [0]I have a dream that one day every valley shall be exalted\" -o output.wav" , argv[0 ]);
19+ LOG (" \n " );
20+ LOG (" \n To use a local model, specify the path to the model file:" );
21+ LOG (" \n %s -p ... -m sesame-csm-backbone.gguf -mv kyutai-mimi.gguf -o output.wav" , argv[0 ]);
22+ LOG (" \n " );
23+ LOG (" \n Note: the model need 2 files to run, one ends with '-backbone-<quant>.gguf' and the other ends with '-decoder<quant>.gguf'" );
1724 LOG (" \n " );
1825}
1926
@@ -51,10 +58,15 @@ static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
5158int main (int argc, char ** argv) {
5259 common_params params;
5360
54- params.model = " sesame-csm-backbone.gguf" ;
55- params.out_file = " output.wav" ;
56- params.prompt = " [0]Hello from Sesame." ;
57- params.n_predict = 2048 ; // CSM's max trained seq length
61+ params.model = " sesame-csm-backbone.gguf" ;
62+ params.vocoder .model = " kyutai-mimi.gguf" ;
63+ params.out_file = " output.wav" ;
64+ params.prompt = " [0]Hello from Sesame." ;
65+ params.n_predict = 2048 ; // CSM's max trained seq length
66+
67+ // HF model
68+ params.model_url = " https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf" ;
69+ params.vocoder .model_url = " https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/kyutai-mimi.gguf" ;
5870
5971 if (!common_params_parse (argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
6072 return 1 ;
@@ -71,6 +83,9 @@ int main(int argc, char ** argv) {
7183 common_params params_decoder (params); // duplicate the params
7284 params_decoder.n_ctx = 64 ; // we never use more than this
7385 string_replace_all (params_decoder.model , " -backbone" , " -decoder" );
86+ if (!params_decoder.model_url .empty ()) {
87+ string_replace_all (params_decoder.model_url , " -backbone" , " -decoder" );
88+ }
7489
7590 common_init_result llama_backbone = common_init_from_params (params);
7691 llama_model * model_bb = llama_backbone.model .get ();
@@ -88,6 +103,8 @@ int main(int argc, char ** argv) {
88103 return ENOENT;
89104 }
90105
106+ mimi_model mimi (params.vocoder .model .c_str (), true );
107+
91108 const llama_vocab * vocab = llama_model_get_vocab (model_bb);
92109 llama_tokens prompt_tokens = common_tokenize (vocab, params.prompt , false , true );
93110 prompt_tokens.insert (prompt_tokens.begin (), llama_vocab_bos (vocab));
@@ -118,6 +135,7 @@ int main(int argc, char ** argv) {
118135 int64_t n_dc_gen = 0 ; // decoder generation count
119136
120137 bool is_stop = false ;
138+ std::vector<int > generated_codes;
121139
122140 // backbone generation loop
123141 for (int k = 0 ; k < params.n_predict ; ++k) {
@@ -150,6 +168,7 @@ int main(int argc, char ** argv) {
150168
151169 llama_token semantic_tok = sample_greedy (logits, llama_vocab_n_tokens (vocab_dc));
152170 printf (" %d," , semantic_tok);
171+ generated_codes.push_back (semantic_tok);
153172
154173 // for (size_t i = 0; i < 10; ++i) {
155174 // printf("%4.2f, ", embd[i]);
@@ -205,6 +224,7 @@ int main(int argc, char ** argv) {
205224 printf (" %d," , acoustic_tok);
206225 tok = acoustic_tok; // next input token
207226 sum_codes += acoustic_tok;
227+ generated_codes.push_back (acoustic_tok);
208228 }
209229
210230 // do progressive hsum of embeddings
@@ -246,5 +266,16 @@ int main(int argc, char ** argv) {
246266 llama_batch_free (batch_prompt);
247267 llama_batch_free (batch_past_embd);
248268
269+ printf (" decode %zu RVQ tokens into wav...\n " , generated_codes.size ());
270+ generated_codes = mimi.transpose_input (generated_codes);
271+ std::vector<float > wav_data = mimi.decode (generated_codes);
272+
273+ if (!save_wav16 (params.out_file .c_str (), wav_data, mimi.get_sample_rate ())) {
274+ LOG_ERR (" Failed to save wav file\n " );
275+ return 1 ;
276+ }
277+
278+ printf (" \n " );
279+
249280 return 0 ;
250281}
0 commit comments