@@ -106,7 +106,7 @@ int main(int argc, char ** argv) {
106106 std::vector<float > inp_past_embd (2048 , 0 .0f );
107107 llama_batch batch_past_embd = llama_batch_init (1 , inp_past_embd.size (), 1 );
108108
109- for (int k = 0 ; k < 4 ; ++k) {
109+ for (int k = 0 ; k < 32 ; ++k) {
110110 if (llama_decode (ctx_bb, k == 0 ? batch : batch_past_embd) != 0 ) {
111111 LOG_ERR (" %s: llama_decode() failed\n " , __func__);
112112 return 1 ;
@@ -121,7 +121,7 @@ int main(int argc, char ** argv) {
121121
122122 llama_token latent_token = sample_greedy (logits, llama_vocab_n_tokens (vocab_dc));
123123 // printf("latent_token: %d\n", latent_token);
124- printf (" %5d, " , latent_token);
124+ printf (" %d, " , latent_token);
125125
126126 // for (size_t i = 0; i < 10; ++i) {
127127 // printf("%4.2f, ", embd[i]);
@@ -149,16 +149,23 @@ int main(int argc, char ** argv) {
149149 llama_decode (ctx_dc, batch_embd);
150150
151151 llama_token audio_token = latent_token;
152- for (int i = 0 ; i < 31 ; ++i) {
152+ int n_codes = 32 ;
153+ int sum_codes = 0 ;
154+ for (int i = 0 ; i < n_codes; ++i) {
153155 common_batch_clear (batch_token);
154156 // encoder vocab is further divided into 32 codebooks, each with 2051 entries
155157 llama_token inp_tok = audio_token + 2051 *i;
156158 common_batch_add (batch_token, inp_tok, i+1 , { 0 }, true );
157159 llama_decode (ctx_dc, batch_token);
158160 auto logits = llama_get_logits_ith (ctx_dc, 0 );
159161 audio_token = sample_greedy (logits, llama_vocab_n_tokens (vocab_dc));
160- printf (" %d," , audio_token);
161- prompt_tokens.push_back (audio_token);
162+
163+ // discard last code
164+ if (i < n_codes - 1 ) {
165+ printf (" %d," , audio_token);
166+ prompt_tokens.push_back (audio_token);
167+ sum_codes += audio_token;
168+ }
162169
163170 GGML_ASSERT (inp_past_embd.size () == embd.size ());
164171 for (size_t i = 0 ; i < inp_past_embd.size (); ++i) {
@@ -169,8 +176,22 @@ int main(int argc, char ** argv) {
169176
170177 llama_batch_free (batch_embd);
171178 llama_batch_free (batch_token);
179+
180+ if (sum_codes == 0 ) {
181+ return 0 ; // done
182+ }
172183 }
173184
185+ // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb);
186+ // for (size_t i = 0; i < inp_past_embd.size(); ++i) {
187+ // printf("%4.4f, ", inp_past_embd[i]);
188+ // if (i == 2) {
189+ // printf("... ");
190+ // i = inp_past_embd.size() - 4;
191+ // }
192+ // }
193+ // printf("\n");
194+
174195 // prepare for the next iteration
175196 {
176197 batch_past_embd.n_tokens = 1 ;
0 commit comments