#include <vector>
#include <fstream>
#include <float.h>
+#include <cstring> // memcpy and strcmp
+#include <inttypes.h>
+
+// For more details on how this works, see: https://github.com/ggml-org/llama.cpp/pull/12648
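+//
+// Rough pipeline, as implemented below: the backbone consumes the text prompt and, for each
+// audio frame, produces one semantic token plus a projected embedding; the decoder then expands
+// that frame into the remaining RVQ codebook tokens, and the sum of the decoder embeddings is
+// fed back into the backbone to generate the next frame.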

static void print_usage(int, char ** argv) {
    LOG("\nexample usage:\n");
@@ -30,6 +34,8 @@ static llama_token sample_greedy(const float * logits, int n_vocab) {
static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
    std::vector<float> * embd = (std::vector<float> *) user_data;

+    // output_csm_proj   is the embedding output from the backbone
+    // output_audio_embd is the embedding output from the decoder
    if (t && (strcmp(t->name, "output_csm_proj") == 0 || strcmp(t->name, "output_audio_embd") == 0)) {
        if (ask) return true;

@@ -45,13 +51,10 @@ static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
int main(int argc, char ** argv) {
    common_params params;

-    params.model = "sesame-csm-backbone.gguf";
-    params.out_file = "output.wav";
-    params.prompt = "[0]Hello from Sesame.";
-
-    params.n_predict = 4096;
-    params.n_batch = 8192;
-    params.n_ctx = 8192;
+    params.model = "sesame-csm-backbone.gguf";
+    params.out_file = "output.wav";
+    params.prompt = "[0]Hello from Sesame.";
+    params.n_predict = 2048; // CSM's max trained seq length

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
        return 1;
@@ -66,6 +69,7 @@ int main(int argc, char ** argv) {
    params.warmup = false;

    common_params params_decoder(params); // duplicate the params
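+    // note: the decoder's KV cache is cleared every frame and each frame uses at most
+    // 1 embedding + 32 codebook positions, so a small context is sufficient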
+    params_decoder.n_ctx = 64; // we never use more than this
    string_replace_all(params_decoder.model, "-backbone", "-decoder");

    common_init_result llama_backbone = common_init_from_params(params);
@@ -96,77 +100,114 @@ int main(int argc, char ** argv) {
    printf("\n");

    llama_pos n_past_bb = 0;
-    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
-    common_batch_clear(batch);
+    llama_batch batch_prompt = llama_batch_init(params.n_batch, 0, 1);
+    common_batch_clear(batch_prompt);
    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-        common_batch_add(batch, prompt_tokens[i], n_past_bb++, { 0 }, false);
+        common_batch_add(batch_prompt, prompt_tokens[i], n_past_bb++, { 0 }, false);
    }
-    batch.logits[batch.n_tokens - 1] = true;
+    batch_prompt.logits[batch_prompt.n_tokens - 1] = true;

+    // inp_past_embd is the "squashed" embeddings from the decoder
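+    // (2048 floats = one embedding vector; its size must match the embeddings captured by
+    //  ggml_callback, which the GGML_ASSERT further down checks)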
    std::vector<float> inp_past_embd(2048, 0.0f);
    llama_batch batch_past_embd = llama_batch_init(1, inp_past_embd.size(), 1);

-    for (int k = 0; k < 32; ++k) {
-        if (llama_decode(ctx_bb, k == 0 ? batch : batch_past_embd) != 0) {
-            LOG_ERR("%s: llama_decode() failed\n", __func__);
+    int64_t t_gb_start = ggml_time_ms(); // global start time
+    int64_t t_bb = 0;     // backbone time
+    int64_t n_bb_gen = 0; // backbone generation count
+    int64_t t_dc = 0;     // decoder time
+    int64_t n_dc_gen = 0; // decoder generation count
+
+    bool is_stop = false;
+
+    // backbone generation loop
+    for (int k = 0; k < params.n_predict; ++k) {
+        bool is_prompt_processing = k == 0;
+
+        if (!is_prompt_processing) {
+            // generate the next RVQ semantic token
+            batch_past_embd.n_tokens = 1;
+            batch_past_embd.pos[0] = n_past_bb++;
+            batch_past_embd.seq_id[0][0] = 0;
+            batch_past_embd.n_seq_id[0] = 1;
+            batch_past_embd.logits[0] = true;
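+            // on generation steps the backbone is fed the summed decoder embeddings
+            // instead of a token id (batch_past_embd was created with an embd size above)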
+            std::memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float));
+        }
+
+        int64_t t_bb_start = ggml_time_ms();
+        if (llama_decode(ctx_bb, is_prompt_processing ? batch_prompt : batch_past_embd) != 0) {
+            LOG_ERR("%s: backbone llama_decode() failed\n", __func__);
            return 1;
        }
+        n_bb_gen++;
+        t_bb += ggml_time_ms() - t_bb_start;

        auto vocab_dc = llama_model_get_vocab(model_dc);
-        auto logits = llama_get_logits_ith(ctx_bb, k == 0 ? (batch.n_tokens - 1) : 0);
+        auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? (batch_prompt.n_tokens - 1) : 0);
        // for (size_t i = 0; i < 10; ++i) {
        //     printf("%4.2f, ", logits[i]);
        // }
        // printf("\n");

-        llama_token latent_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
-        // printf("latent_token: %d\n", latent_token);
-        printf("%d,", latent_token);
+        llama_token semantic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
+        printf("%d,", semantic_tok);
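+        // semantic_tok is effectively the first (0th) RVQ codebook entry of this frame;
+        // the remaining 31 codebooks are generated by the decoder below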

        // for (size_t i = 0; i < 10; ++i) {
        //     printf("%4.2f, ", embd[i]);
        // printf("\n");


-

-        // decode
-        prompt_tokens.clear();
-        prompt_tokens.push_back(latent_token);
+        // decoder generation loop
        inp_past_embd = std::vector<float>(inp_past_embd.size(), 0.0f);
        {
            llama_kv_self_clear(ctx_dc);
            llama_batch batch_embd = llama_batch_init(1, embd.size(), 1);
            llama_batch batch_token = llama_batch_init(1, 0, 1);
+
+            // the first "token" is the latent embedding from the backbone
            {
                batch_embd.n_tokens = 1;
                batch_embd.pos[0] = 0;
                batch_embd.seq_id[0][0] = 0;
                batch_embd.n_seq_id[0] = 1;
                batch_embd.logits[0] = false;
-                memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
+                std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
+            }
+            if (llama_decode(ctx_dc, batch_embd) != 0) {
+                LOG_ERR("%s: decoder llama_decode(embd) failed\n", __func__);
+                return 1;
            }
-            llama_decode(ctx_dc, batch_embd);
-
-            llama_token audio_token = latent_token;
+
+            // then, decode the semantic_tok to generate acoustic tokens
+            llama_token tok = semantic_tok;
            int n_codes = 32;
-            int sum_codes = 0;
+            int sum_codes = 0; // to check if all codes are 0
            for (int i = 0; i < n_codes; ++i) {
                common_batch_clear(batch_token);
                // encoder vocab is further divided into 32 codebooks, each with 2051 entries
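+                // (position 0 in the decoder batch holds the backbone embedding, so the
+                //  codebook tokens below are placed at positions 1..32, hence the i+1)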
-                llama_token inp_tok = audio_token + 2051*i;
+                llama_token inp_tok = tok + 2051*i;
                common_batch_add(batch_token, inp_tok, i+1, { 0 }, true);
-                llama_decode(ctx_dc, batch_token);
+
+                int64_t t_bb_start = ggml_time_ms();
+                if (llama_decode(ctx_dc, batch_token) != 0) {
+                    LOG_ERR("%s: decoder llama_decode(token) failed\n", __func__);
+                    return 1;
+                }
+                n_dc_gen++;
+                t_dc += ggml_time_ms() - t_bb_start;
+
+                // sample the acoustic token
                auto logits = llama_get_logits_ith(ctx_dc, 0);
-                audio_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
+                llama_token acoustic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));

-                // discard last code
+                // discard last code (only for embeddings)
                if (i < n_codes - 1) {
-                    printf("%d,", audio_token);
-                    prompt_tokens.push_back(audio_token);
-                    sum_codes += audio_token;
+                    printf("%d,", acoustic_tok);
+                    tok = acoustic_tok; // next input token
+                    sum_codes += acoustic_tok;
                }

+                // do progressive hsum of embeddings
                GGML_ASSERT(inp_past_embd.size() == embd.size());
                for (size_t i = 0; i < inp_past_embd.size(); ++i) {
                    inp_past_embd[i] += embd[i];
@@ -177,9 +218,8 @@ int main(int argc, char ** argv) {
            llama_batch_free(batch_embd);
            llama_batch_free(batch_token);

-            if (sum_codes == 0) {
-                return 0; // done
-            }
+            // if all codes are 0, then we are done
+            is_stop = sum_codes == 0;
        }

        // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb);
@@ -192,17 +232,19 @@ int main(int argc, char ** argv) {
        // }
        // printf("\n");

-        // prepare for the next iteration
-        {
-            batch_past_embd.n_tokens = 1;
-            batch_past_embd.pos[0] = n_past_bb;
-            batch_past_embd.seq_id[0][0] = 0;
-            batch_past_embd.n_seq_id[0] = 1;
-            batch_past_embd.logits[0] = true;
-            memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float));
+        if (is_stop) {
+            break;
        }
-        n_past_bb++;
    }

+    // print timing info
+    printf("\ntimings:\n");
+    printf("  backbone: %" PRId64 " ms, %" PRId64 " generated tokens (%.2f tok/s)\n", t_bb, n_bb_gen, (float)n_bb_gen*1000/(float)t_bb);
+    printf("  decoder:  %" PRId64 " ms, %" PRId64 " generated tokens (%.2f tok/s)\n", t_dc, n_dc_gen, (float)n_dc_gen*1000/(float)t_dc);
+    printf("  total:    %" PRId64 " ms\n\n", ggml_time_ms() - t_gb_start);
+
+    llama_batch_free(batch_prompt);
+    llama_batch_free(batch_past_embd);
+
    return 0;
}