Commit f9162e7

wip
1 parent 2d743b6 commit f9162e7

1 file changed (+26, -5 lines)


examples/tts/tts-csm.cpp

Lines changed: 26 additions & 5 deletions
@@ -106,7 +106,7 @@ int main(int argc, char ** argv) {
     std::vector<float> inp_past_embd(2048, 0.0f);
     llama_batch batch_past_embd = llama_batch_init(1, inp_past_embd.size(), 1);
 
-    for (int k = 0; k < 4; ++k) {
+    for (int k = 0; k < 32; ++k) {
         if (llama_decode(ctx_bb, k == 0 ? batch : batch_past_embd) != 0) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return 1;
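
Note on the batch setup above: per llama.h, `llama_batch_init(n_tokens, embd, n_seq_max)` with a non-zero `embd` argument allocates `batch.embd` (`n_tokens * embd` floats) instead of `batch.token`, so `batch_past_embd` carries a raw embedding rather than token ids. A minimal sketch of populating such a one-position embedding batch; the helper is illustrative, not part of this commit:

    #include <algorithm> // std::copy
    #include <vector>
    #include "llama.h"

    // Illustrative helper, not from this commit: fill a one-position embedding
    // batch like batch_past_embd before handing it to llama_decode().
    static void set_embd_batch(llama_batch & batch, const std::vector<float> & e,
                               llama_pos pos, llama_seq_id seq) {
        batch.n_tokens = 1;
        std::copy(e.begin(), e.end(), batch.embd); // embd area sized n_tokens * n_embd
        batch.pos[0]       = pos;  // position in the sequence / KV cache
        batch.n_seq_id[0]  = 1;
        batch.seq_id[0][0] = seq;
        batch.logits[0]    = true; // request logits for this position
    }
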
@@ -121,7 +121,7 @@ int main(int argc, char ** argv) {
 
         llama_token latent_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
         // printf("latent_token: %d\n", latent_token);
-        printf("%5d, ", latent_token);
+        printf("%d,", latent_token);
 
         // for (size_t i = 0; i < 10; ++i) {
         //     printf("%4.2f, ", embd[i]);
@@ -149,16 +149,23 @@ int main(int argc, char ** argv) {
         llama_decode(ctx_dc, batch_embd);
 
         llama_token audio_token = latent_token;
-        for (int i = 0; i < 31; ++i) {
+        int n_codes = 32;
+        int sum_codes = 0;
+        for (int i = 0; i < n_codes; ++i) {
             common_batch_clear(batch_token);
             // encoder vocab is further divided into 32 codebooks, each with 2051 entries
             llama_token inp_tok = audio_token + 2051*i;
             common_batch_add(batch_token, inp_tok, i+1, { 0 }, true);
             llama_decode(ctx_dc, batch_token);
             auto logits = llama_get_logits_ith(ctx_dc, 0);
             audio_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
-            printf("%d,", audio_token);
-            prompt_tokens.push_back(audio_token);
+
+            // discard last code
+            if (i < n_codes - 1) {
+                printf("%d,", audio_token);
+                prompt_tokens.push_back(audio_token);
+                sum_codes += audio_token;
+            }
 
             GGML_ASSERT(inp_past_embd.size() == embd.size());
             for (size_t i = 0; i < inp_past_embd.size(); ++i) {
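
The `inp_tok = audio_token + 2051*i` line relies on the decoder vocabulary being a flat concatenation of 32 codebooks with 2051 entries each (per the comment in the hunk), so a within-codebook code is offset by `2051 * codebook_index` to get its flat vocab id. A sketch of that mapping and its inverse, with hypothetical helper names:

    // Assumed layout (from the in-diff comment): 32 codebooks x 2051 entries,
    // concatenated into one decoder vocab. Helper names are hypothetical.
    constexpr int N_CODEBOOKS   = 32;
    constexpr int CODEBOOK_SIZE = 2051;

    inline int flat_id(int codebook, int code) {
        return codebook * CODEBOOK_SIZE + code;   // forward mapping used by the hunk
    }
    inline int codebook_of(int id) { return id / CODEBOOK_SIZE; } // inverse: which codebook
    inline int code_of(int id)     { return id % CODEBOOK_SIZE; } // inverse: code within it
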
@@ -169,8 +176,22 @@ int main(int argc, char ** argv) {
 
         llama_batch_free(batch_embd);
         llama_batch_free(batch_token);
+
+        if (sum_codes == 0) {
+            return 0; // done
+        }
     }
 
+    // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb);
+    // for (size_t i = 0; i < inp_past_embd.size(); ++i) {
+    //     printf("%4.4f, ", inp_past_embd[i]);
+    //     if (i == 2) {
+    //         printf("... ");
+    //         i = inp_past_embd.size() - 4;
+    //     }
+    // }
+    // printf("\n");
+
     // prepare for the next iteration
     {
         batch_past_embd.n_tokens = 1;
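
The new `sum_codes` bookkeeping doubles as the stop condition: token ids are non-negative, so `sum_codes == 0` holds exactly when every kept code in the frame is 0, which this commit appears to treat as the end of the audio stream. A more explicit, hypothetical equivalent of the same test:

    #include <vector>
    #include "llama.h"

    // Hypothetical equivalent of the sum_codes == 0 check: since codes are
    // non-negative, a zero sum means the whole frame is zeros (end of audio).
    static bool frame_is_eos(const std::vector<llama_token> & frame_codes) {
        for (llama_token c : frame_codes) {
            if (c != 0) {
                return false; // any non-zero code -> keep generating
            }
        }
        return true;
    }
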
