Commit c681257

ability to do multi-turns
1 parent e9dc476 commit c681257

2 files changed: +171 -127 lines


examples/tts/csm-demo.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+[0]Hey how are you doing.
+[1]Pretty good, pretty good.
+[0]I'm great, so happy to be speaking to you.
+What about you?
+[1]Me too, this is some cool stuff huh?

examples/tts/tts-csm.cpp

Lines changed: 166 additions & 127 deletions
@@ -5,6 +5,7 @@
 #include "mimi-model.h"
 
 #include <vector>
+#include <regex>
 #include <fstream>
 #include <float.h>
 #include <cstring> // memcpy and strcmp
@@ -23,12 +24,39 @@ static void print_usage(int, char ** argv) {
     LOG("\n Note: the model need 2 files to run, one ends with '-backbone-<quant>.gguf' and the other ends with '-decoder<quant>.gguf'");
     LOG("\n");
     LOG("\nPrompt format:");
-    LOG("\n Each line must start with speaker ID in square brackets, followed by the text. A full stop is recommended at the end of each turn");
-    LOG("\n Example: [0]Hello world.");
+    LOG("\n Each line must start with speaker ID in square brackets, followed by the text. One turn per line. A full stop is recommended at the end of each turn");
+    LOG("\n Example:");
+    LOG("\n [0]Hey how are you doing.");
+    LOG("\n [1]Pretty good, pretty good.");
     LOG("\n If you want to enter long text, use -f file.txt to read from file");
     LOG("\n");
 }
 
+// split text containing "[N]..." into speaker turns
+static std::vector<std::string> get_speaker_turns(const std::string & input) {
+    if (input.empty()) {
+        LOG_ERR("Empty input\n");
+        return {};
+    }
+    if (input[0] != '[') {
+        LOG_ERR("Invalid input format: missing speaker ID\n");
+        return {};
+    }
+    std::regex re(R"((\[\d+\][\s\S]*?)(?=\[\d+\]|$))");
+    std::smatch match;
+    std::vector<std::string> turns;
+    std::string::const_iterator searchStart(input.cbegin());
+    while (std::regex_search(searchStart, input.cend(), match, re)) {
+        std::string turn = match[1].str();
+        if (turn.empty()) {
+            continue;
+        }
+        turns.push_back(turn);
+        searchStart = match.suffix().first;
+    }
+    return turns;
+}
+
 // sampling with custom n_vocab
 // modified version of llama_sampler_sample()
 static llama_token sample_token(struct llama_sampler * smpl, const float * logits, int n_vocab) {
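
Note: the splitting above relies on a lazy match plus a lookahead, so each turn starts at "[N]" and runs until the next "[N]" or the end of the input, and a turn may span several lines. A minimal standalone sketch of the same idea (not part of the commit; the main() wrapper, includes and printing are illustrative only), applied to the transcript from examples/tts/csm-demo.txt:

#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main() {
    // same transcript as examples/tts/csm-demo.txt; speaker 0's second turn
    // continues on a following line without a new "[N]" tag
    std::string input =
        "[0]Hey how are you doing.\n"
        "[1]Pretty good, pretty good.\n"
        "[0]I'm great, so happy to be speaking to you.\n"
        "What about you?\n"
        "[1]Me too, this is some cool stuff huh?\n";

    // same pattern as get_speaker_turns(): capture from "[N]" lazily up to the position
    // where the next "[N]" (or the end of input) begins; [\s\S] also matches newlines
    std::regex re(R"((\[\d+\][\s\S]*?)(?=\[\d+\]|$))");
    std::smatch match;
    std::vector<std::string> turns;
    auto search_start = input.cbegin();
    while (std::regex_search(search_start, input.cend(), match, re)) {
        turns.push_back(match[1].str());
        search_start = match.suffix().first;
    }

    // expected: 4 turns, with "What about you?" attached to the preceding [0] turn
    for (size_t i = 0; i < turns.size(); ++i) {
        std::cout << "turn " << i << ": " << turns[i];
    }
    return 0;
}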
@@ -81,9 +109,11 @@ int main(int argc, char ** argv) {
     params.sampling.top_k = 50; // default param from CSM python code
     params.sampling.temp = 0.9; // default param from CSM python code
 
-    // HF model
-    params.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf";
-    params.vocoder.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/kyutai-mimi.gguf";
+    // HF model (hack: we temporary reuse speculative.model as the decoder model, only to get it downloaded)
+    params.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf";
+    params.speculative.model.path = "sesame-csm-decoder.gguf";
+    params.speculative.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-decoder.gguf";
+    params.vocoder.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/kyutai-mimi.gguf";
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
         return 1;
@@ -125,32 +155,15 @@ int main(int argc, char ** argv) {
 
     mimi_model mimi(params.vocoder.model.path.c_str(), true);
 
-    // tokenize the prompt
-    const llama_vocab * vocab = llama_model_get_vocab(model_bb);
-    llama_tokens prompt_tokens = common_tokenize(vocab, params.prompt, false, true);
-    prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab));
-    prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab));
-
     // init sampler
     // the python implementation only has top-k and temperature sampling, so we'll use just that
     llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
     llama_sampler_chain_add(sampler.get(), llama_sampler_init_top_k(params.sampling.top_k));
     llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(params.sampling.temp));
     llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(params.sampling.seed));
 
-    printf("prompt tokens: \n");
-    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-        printf("%d, ", prompt_tokens[i]);
-    }
-    printf("\n");
-
-    llama_pos n_past_bb = 0;
     llama_batch batch_prompt = llama_batch_init(params.n_batch, 0, 1);
-    common_batch_clear(batch_prompt);
-    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-        common_batch_add(batch_prompt, prompt_tokens[i], n_past_bb++, { 0 }, false);
-    }
-    batch_prompt.logits[batch_prompt.n_tokens - 1] = true;
+    llama_pos n_past_bb = 0;
 
     // inp_past_embd is the "squashed" embeddings from the decoder
     std::vector<float> inp_past_embd(2048, 0.0f);
@@ -162,128 +175,154 @@ int main(int argc, char ** argv) {
     int64_t t_dc = 0; // decoder time
     int64_t n_dc_gen = 0; // decoder generation count
 
-    bool is_stop = false;
     std::vector<int> generated_codes;
 
-    // backbone generation loop
-    for (int k = 0; k < params.n_predict; ++k) {
-        bool is_prompt_processing = k == 0;
-
-        if (!is_prompt_processing) {
-            // generate the next RVQ semantic token
-            batch_past_embd.n_tokens = 1;
-            batch_past_embd.pos[0] = n_past_bb++;
-            batch_past_embd.seq_id[0][0] = 0;
-            batch_past_embd.n_seq_id[0] = 1;
-            batch_past_embd.logits[0] = true;
-            std::memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float));
-        }
+    auto turns = get_speaker_turns(params.prompt);
+
+    for (const std::string & turn : turns) {
+        // tokenize the turn
+        llama_tokens prompt_tokens;
+        {
+            printf("\n---\nturn: %s\n\n", turn.c_str());
+            const llama_vocab * vocab = llama_model_get_vocab(model_bb);
+            prompt_tokens = common_tokenize(vocab, turn, false, true);
+            prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab));
+            prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab));
+
+            printf("prompt (%zu tokens): \n", prompt_tokens.size());
+            for (size_t i = 0; i < prompt_tokens.size(); ++i) {
+                printf("%d, ", prompt_tokens[i]);
+            }
+            printf("\n");
 
-        int64_t t_bb_start = ggml_time_ms();
-        if (llama_decode(ctx_bb, is_prompt_processing ? batch_prompt : batch_past_embd) != 0) {
-            LOG_ERR("%s: backbone llama_decode() failed\n", __func__);
-            return 1;
+            common_batch_clear(batch_prompt);
+            for (size_t i = 0; i < prompt_tokens.size(); ++i) {
+                common_batch_add(batch_prompt, prompt_tokens[i], n_past_bb++, { 0 }, false);
+            }
+            batch_prompt.logits[batch_prompt.n_tokens - 1] = true;
         }
-        n_bb_gen++;
-        t_bb += ggml_time_ms() - t_bb_start;
 
-        auto vocab_dc = llama_model_get_vocab(model_dc);
-        auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? (batch_prompt.n_tokens - 1) : 0);
-        // for (size_t i = 0; i < 10; ++i) {
-        //     printf("%4.2f, ", logits[i]);
-        // }
-        // printf("\n");
+        // backbone generation loop
+        bool is_end_of_turn = false;
+        for (int k = 0; k < params.n_predict; ++k) {
+            bool is_prompt_processing = k == 0;
+
+            if (!is_prompt_processing) {
+                // generate the next RVQ semantic token
+                batch_past_embd.n_tokens = 1;
+                batch_past_embd.pos[0] = n_past_bb++;
+                batch_past_embd.seq_id[0][0] = 0;
+                batch_past_embd.n_seq_id[0] = 1;
+                batch_past_embd.logits[0] = true;
+                std::memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float));
+            }
 
-        llama_token semantic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
-        printf("Sem token %5d : %d,", 1+(int)generated_codes.size()/32, semantic_tok);
-        generated_codes.push_back(semantic_tok);
+            int64_t t_bb_start = ggml_time_ms();
+            if (llama_decode(ctx_bb, is_prompt_processing ? batch_prompt : batch_past_embd) != 0) {
+                LOG_ERR("%s: backbone llama_decode() failed\n", __func__);
+                return 1;
+            }
+            n_bb_gen++;
+            t_bb += ggml_time_ms() - t_bb_start;
 
-        // for (size_t i = 0; i < 10; ++i) {
-        //     printf("%4.2f, ", embd[i]);
-        // }
-        // printf("\n");
+            auto vocab_dc = llama_model_get_vocab(model_dc);
+            auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? (batch_prompt.n_tokens - 1) : 0);
+            // for (size_t i = 0; i < 10; ++i) {
+            //     printf("%4.2f, ", logits[i]);
+            // }
+            // printf("\n");
 
+            llama_token semantic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
+            printf("Sem token %5d : %d,", 1+(int)generated_codes.size()/32, semantic_tok);
+            generated_codes.push_back(semantic_tok);
 
-        // decoder generation loop
-        inp_past_embd = std::vector<float>(inp_past_embd.size(), 0.0f);
-        {
-            llama_kv_self_clear(ctx_dc);
-            llama_batch batch_embd = llama_batch_init(1, embd.size(), 1);
-            llama_batch batch_token = llama_batch_init(1, 0, 1);
+            // for (size_t i = 0; i < 10; ++i) {
+            //     printf("%4.2f, ", embd[i]);
+            // }
+            // printf("\n");
 
-            // first "token" is the latent embeddings from backbone
-            {
-                batch_embd.n_tokens = 1;
-                batch_embd.pos[0] = 0;
-                batch_embd.seq_id[0][0] = 0;
-                batch_embd.n_seq_id[0] = 1;
-                batch_embd.logits[0] = false;
-                std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
-            }
-            if (llama_decode(ctx_dc, batch_embd) != 0) {
-                LOG_ERR("%s: decoder llama_decode(embd) failed\n", __func__);
-                return 1;
-            }
+            // decoder generation loop
+            inp_past_embd = std::vector<float>(inp_past_embd.size(), 0.0f);
+            {
+                llama_kv_self_clear(ctx_dc);
+                llama_batch batch_embd = llama_batch_init(1, embd.size(), 1);
+                llama_batch batch_token = llama_batch_init(1, 0, 1);
+
+                // first "token" is the latent embeddings from backbone
+                {
+                    batch_embd.n_tokens = 1;
+                    batch_embd.pos[0] = 0;
+                    batch_embd.seq_id[0][0] = 0;
+                    batch_embd.n_seq_id[0] = 1;
+                    batch_embd.logits[0] = false;
+                    std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
                 }
-            // then, decode the semantic_tok to generate acoustic tokens
-            llama_token tok = semantic_tok;
-            int n_codes = 32;
-            int sum_codes = semantic_tok; // to check if all codes are 0
-            for (int i = 0; i < n_codes; ++i) {
-                common_batch_clear(batch_token);
-                // encoder vocab is further divided into 32 codebooks, each with 2051 entries
-                llama_token inp_tok = tok + 2051*i;
-                common_batch_add(batch_token, inp_tok, i+1, { 0 }, true);
-
-                int64_t t_bb_start = ggml_time_ms();
-                if (llama_decode(ctx_dc, batch_token) != 0) {
-                    LOG_ERR("%s: decoder llama_decode(token) failed\n", __func__);
-                    return 1;
-                }
-                n_dc_gen++;
-                t_dc += ggml_time_ms() - t_bb_start;
-
-                // sample the acoustic token
-                auto logits = llama_get_logits_ith(ctx_dc, 0);
-                llama_token acoustic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
-
-                // discard last code (only for embeddings)
-                if (i < n_codes - 1) {
-                    printf("%d,", acoustic_tok);
-                    tok = acoustic_tok; // next input token
-                    sum_codes += acoustic_tok;
-                    generated_codes.push_back(acoustic_tok);
+                if (llama_decode(ctx_dc, batch_embd) != 0) {
+                    LOG_ERR("%s: decoder llama_decode(embd) failed\n", __func__);
+                    return 1;
                 }
 
-                // do progressive hsum of embeddings
-                GGML_ASSERT(inp_past_embd.size() == embd.size());
-                for (size_t i = 0; i < inp_past_embd.size(); ++i) {
-                    inp_past_embd[i] += embd[i];
+                // then, decode the semantic_tok to generate acoustic tokens
+                llama_token tok = semantic_tok;
+                int n_codes = 32;
+                int sum_codes = semantic_tok; // to check if all codes are 0
+                for (int i = 0; i < n_codes; ++i) {
+                    common_batch_clear(batch_token);
+                    // encoder vocab is further divided into 32 codebooks, each with 2051 entries
+                    llama_token inp_tok = tok + 2051*i;
+                    common_batch_add(batch_token, inp_tok, i+1, { 0 }, true);
+
+                    int64_t t_bb_start = ggml_time_ms();
+                    if (llama_decode(ctx_dc, batch_token) != 0) {
+                        LOG_ERR("%s: decoder llama_decode(token) failed\n", __func__);
+                        return 1;
+                    }
+                    n_dc_gen++;
+                    t_dc += ggml_time_ms() - t_bb_start;
+
+                    // sample the acoustic token
+                    auto logits = llama_get_logits_ith(ctx_dc, 0);
+                    llama_token acoustic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
+
+                    // discard last code (only for embeddings)
+                    if (i < n_codes - 1) {
+                        printf("%d,", acoustic_tok);
+                        tok = acoustic_tok; // next input token
+                        sum_codes += acoustic_tok;
+                        generated_codes.push_back(acoustic_tok);
+                    }
+
+                    // do progressive hsum of embeddings
+                    GGML_ASSERT(inp_past_embd.size() == embd.size());
+                    for (size_t i = 0; i < inp_past_embd.size(); ++i) {
+                        inp_past_embd[i] += embd[i];
+                    }
                 }
-            }
-            printf("\n");
+                printf("\n");
 
-            llama_batch_free(batch_embd);
-            llama_batch_free(batch_token);
+                llama_batch_free(batch_embd);
+                llama_batch_free(batch_token);
 
-            // if all codes are 0, then we are done
-            is_stop = sum_codes == 0;
-        }
+                // if all codes are 0, then we are done
+                is_end_of_turn = sum_codes == 0;
+            }
 
-        // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb);
-        // for (size_t i = 0; i < inp_past_embd.size(); ++i) {
-        //     printf("%4.4f, ", inp_past_embd[i]);
-        //     if (i == 2) {
-        //         printf("... ");
-        //         i = inp_past_embd.size() - 4;
-        //     }
-        // }
-        // printf("\n");
-
-        if (is_stop) {
-            // remove last 32 codes since they will be all zeros
-            generated_codes.resize(generated_codes.size() - 32);
-            break;
+            // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb);
+            // for (size_t i = 0; i < inp_past_embd.size(); ++i) {
+            //     printf("%4.4f, ", inp_past_embd[i]);
+            //     if (i == 2) {
+            //         printf("... ");
+            //         i = inp_past_embd.size() - 4;
+            //     }
+            // }
+            // printf("\n");
+
+            if (is_end_of_turn) {
+                // remove last 32 codes since they will be all zeros
+                generated_codes.resize(generated_codes.size() - 32);
+                break;
+            }
         }
     }
 
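
Note on the decoder loop above: the decoder treats its vocabulary as 32 RVQ codebooks of 2051 entries laid out flat, which is why the sampled code is offset with "tok + 2051*i" before being fed back for codebook i, and why the per-step embeddings are summed into inp_past_embd (the "squashed" input for the backbone's next frame). A minimal sketch of just the flattening arithmetic (the constants come from the code above; the helper name and main() are made up for illustration, not part of the commit):

#include <cassert>
#include <cstdio>

static const int N_CODEBOOKS   = 32;   // RVQ codebooks per audio frame (from the example)
static const int CODEBOOK_SIZE = 2051; // entries per codebook in the decoder vocab (from the example)

// map (codebook index, code within that codebook) -> flattened decoder token id
static int flatten_code(int codebook, int code) {
    assert(codebook >= 0 && codebook < N_CODEBOOKS);
    assert(code     >= 0 && code     < CODEBOOK_SIZE);
    return code + CODEBOOK_SIZE * codebook;
}

int main() {
    // e.g. code 1234 in codebook 3 becomes token 1234 + 2051*3 = 7387
    printf("flattened id: %d\n", flatten_code(3, 1234));

    // and a flattened id can be split back into (codebook, code)
    int id = 7387;
    printf("codebook %d, code %d\n", id / CODEBOOK_SIZE, id % CODEBOOK_SIZE);
    return 0;
}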