Skip to content

Commit d178099

Browse files
committed
add audio EOS token
1 parent 142b545 commit d178099

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

examples/tts/tts-csm.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,11 @@ int main(int argc, char ** argv) {
225225
n_bb_gen++;
226226
t_bb += ggml_time_ms() - t_bb_start;
227227

228+
if (is_end_of_turn) {
229+
// done decoding audio's EOS token
230+
break;
231+
}
232+
228233
auto vocab_dc = llama_model_get_vocab(model_dc);
229234
auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? (batch_prompt.n_tokens - 1) : 0);
230235
// for (size_t i = 0; i < 10; ++i) {
@@ -304,8 +309,13 @@ int main(int argc, char ** argv) {
304309
llama_batch_free(batch_embd);
305310
llama_batch_free(batch_token);
306311

307-
// if all codes are 0, then we are done
312+
// if all codes are 0, then we are done (got audio EOS token)
313+
// note: we still need to run backbone decode one more time to decode the audio's EOS token
308314
is_end_of_turn = sum_codes == 0;
315+
if (is_end_of_turn) {
316+
// remove last 32 codes since they will be all zeros
317+
generated_codes.resize(generated_codes.size() - 32);
318+
}
309319
}
310320

311321
// printf("inp_past_embd, n_past_bb = %d\n", n_past_bb);
@@ -317,12 +327,6 @@ int main(int argc, char ** argv) {
317327
// }
318328
// }
319329
// printf("\n");
320-
321-
if (is_end_of_turn) {
322-
// remove last 32 codes since they will be all zeros
323-
generated_codes.resize(generated_codes.size() - 32);
324-
break;
325-
}
326330
}
327331
}
328332

0 commit comments

Comments
 (0)