@@ -30,13 +30,12 @@ static llama_token sample_greedy(const float * logits, int n_vocab) {
 static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
     std::vector<float> * embd = (std::vector<float> *) user_data;
 
-    if (t && strcmp(t->name, "result_norm") == 0) {
+    if (t && (strcmp(t->name, "output_csm_proj") == 0 || strcmp(t->name, "output_audio_embd") == 0)) {
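+        // the callback now watches for two tensors instead of result_norm
+        // (assumed: the backbone's projected hidden state and the decoder's
+        // audio embedding, going by the tensor names)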
         if (ask) return true;
 
-        auto n_bytes = ggml_nbytes(t);
-        embd->resize(n_bytes);
-        ggml_backend_tensor_get(t, embd->data(), 0, n_bytes);
-        printf("result_norm\n");
+        embd->resize(ggml_nelements(t));
+        ggml_backend_tensor_get(t, embd->data(), 0, ggml_nbytes(t));
+        // printf("%s tensor size: %lld, %lld\n", t->name, t->ne[0], t->ne[1]);
         return true;
     }
 
@@ -54,34 +53,37 @@ int main(int argc, char ** argv) {
     params.n_batch = 8192;
     params.n_ctx = 8192;
 
-    params.sampling.top_k = 4;
-    params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, };
-
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
         return 1;
     }
 
     llama_backend_init();
     llama_numa_init(params.numa);
 
-    common_params params_decoder(params); // duplicate the params
-    string_replace_all(params_decoder.model, "-backbone", "-decoder");
-
     std::vector<float> embd;
     params.cb_eval = ggml_callback;
     params.cb_eval_user_data = &embd;
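+    // skip the warmup run; presumably it would fire the eval callback above
+    // with throwaway activations before the real prompt is decoded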
+    params.warmup = false;
+
+    common_params params_decoder(params); // duplicate the params
+    string_replace_all(params_decoder.model, "-backbone", "-decoder");
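+    // assumes the two GGUFs follow a *-backbone / *-decoder naming convention,
+    // e.g. csm-backbone.gguf sitting next to csm-decoder.gguf (hypothetical names)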
+
     common_init_result llama_backbone = common_init_from_params(params);
     llama_model * model_bb = llama_backbone.model.get();
     llama_context * ctx_bb = llama_backbone.context.get();
 
-    // common_init_result llama_decoder = common_init_from_params(params_decoder);
-    // llama_model * model_dc = llama_decoder.model.get();
-    // llama_context * ctx_dc = llama_decoder.context.get();
+    common_init_result llama_decoder = common_init_from_params(params_decoder);
+    llama_model * model_dc = llama_decoder.model.get();
+    llama_context * ctx_dc = llama_decoder.context.get();
 
     if (model_bb == nullptr || ctx_bb == nullptr) {
         return ENOENT;
     }
 
+    if (model_dc == nullptr || ctx_dc == nullptr) {
+        return ENOENT;
+    }
+
     const llama_vocab * vocab = llama_model_get_vocab(model_bb);
     llama_tokens prompt_tokens = common_tokenize(vocab, params.prompt, false, true);
     prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab));
@@ -93,27 +95,92 @@ int main(int argc, char ** argv) {
     }
     printf("\n");
 
+    llama_pos n_past_bb = 0;
     llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
+    common_batch_clear(batch);
     for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-        common_batch_add(batch, prompt_tokens[i], i, { 0 }, false);
+        common_batch_add(batch, prompt_tokens[i], n_past_bb++, { 0 }, false);
     }
     batch.logits[batch.n_tokens - 1] = true;
 
-    if (llama_decode(ctx_bb, batch) != 0) {
-        LOG_ERR("%s: llama_decode() failed\n", __func__);
-        return 1;
-    }
+    std::vector<float> inp_past_embd(2048, 0.0f);
+    llama_batch batch_past_embd = llama_batch_init(1, inp_past_embd.size(), 1);
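+    // inp_past_embd lets later backbone steps consume an embedding instead of a
+    // token: the previous frame's codebook embeddings summed together (see the
+    // accumulation loop below; 2048 is assumed to be the backbone's n_embd)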
 
-    // auto vocab_dc = llama_model_get_vocab(model_dc);
-    auto logits = llama_get_logits_ith(ctx_bb, batch.n_tokens - 1);
-    // printf("next tok: %d\n", sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)));
-    for (size_t i = 0; i < 10; ++i) {
-        printf("%4.2f, ", logits[i]);
-    }
-    printf("next tok: %d\n", sample_greedy(logits, 65632));
+    for (int k = 0; k < 4; ++k) {
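+        // hard-coded 4 output frames for now; the first backbone pass consumes
+        // the text prompt, later passes consume the summed frame embedding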
+        if (llama_decode(ctx_bb, k == 0 ? batch : batch_past_embd) != 0) {
+            LOG_ERR("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+
+        auto vocab_dc = llama_model_get_vocab(model_dc);
+        auto logits = llama_get_logits_ith(ctx_bb, k == 0 ? (batch.n_tokens - 1) : 0);
+        // for (size_t i = 0; i < 10; ++i) {
+        //     printf("%4.2f, ", logits[i]);
+        // }
+        // printf("\n");
+
+        llama_token latent_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
+        // printf("latent_token: %d\n", latent_token);
+        printf("%5d, ", latent_token);
+
+        // for (size_t i = 0; i < 10; ++i) {
+        //     printf("%4.2f, ", embd[i]);
+        // }
+        // printf("\n");
+
+        // decode
+        prompt_tokens.clear();
+        prompt_tokens.push_back(latent_token);
+        inp_past_embd = std::vector<float>(inp_past_embd.size(), 0.0f);
+        {
+            llama_kv_self_clear(ctx_dc);
+            llama_batch batch_embd = llama_batch_init(1, embd.size(), 1);
+            llama_batch batch_token = llama_batch_init(1, 0, 1);
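+            // batch_embd carries raw embeddings (n_embd floats per position)
+            // while batch_token carries ordinary token ids; llama_batch_init's
+            // second argument selects between the two layouts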
+            {
+                batch_embd.n_tokens = 1;
+                batch_embd.pos[0] = 0;
+                batch_embd.seq_id[0][0] = 0;
+                batch_embd.n_seq_id[0] = 1;
+                batch_embd.logits[0] = false;
+                memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
+            }
+            llama_decode(ctx_dc, batch_embd);
+
+            llama_token audio_token = latent_token;
+            for (int i = 0; i < 31; ++i) {
+                common_batch_clear(batch_token);
+                // decoder vocab is further divided into 32 codebooks, each with 2051 entries
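+                // e.g. token 7 from codebook 3 maps to flat id 7 + 3*2051 = 6160,
+                // so each codebook occupies its own slice of the embedding table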
+                llama_token inp_tok = audio_token + 2051*i;
+                common_batch_add(batch_token, inp_tok, i+1, { 0 }, true);
+                llama_decode(ctx_dc, batch_token);
+                auto logits = llama_get_logits_ith(ctx_dc, 0);
+                audio_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
+                printf("%d,", audio_token);
+                prompt_tokens.push_back(audio_token);
+
+                GGML_ASSERT(inp_past_embd.size() == embd.size());
+                for (size_t i = 0; i < inp_past_embd.size(); ++i) {
+                    inp_past_embd[i] += embd[i];
+                }
+            }
+            printf("\n");
+
+            llama_batch_free(batch_embd);
+            llama_batch_free(batch_token);
+        }
 
-    for (size_t i = 0; i < 10; ++i) {
-        printf("%4.2f, ", embd[i]);
+        // prepare for the next iteration
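+        // feed the accumulated frame embedding back to the backbone as a
+        // one-token embedding batch at the next position, with logits enabled
+        // so the next codebook-0 token can be sampled from it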
+        {
+            batch_past_embd.n_tokens = 1;
+            batch_past_embd.pos[0] = n_past_bb;
+            batch_past_embd.seq_id[0][0] = 0;
+            batch_past_embd.n_seq_id[0] = 1;
+            batch_past_embd.logits[0] = true;
+            memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float));
+        }
+        n_past_bb++;
     }
 
     return 0;