
Commit 4012054

clean up
1 parent f9162e7 commit 4012054

3 files changed: +120, -47 lines changed

examples/tts/README-csm.md

Lines changed: 31 additions & 0 deletions

```diff
@@ -0,0 +1,31 @@
+# Sesame CSM
+
+To get the GGUF:
+
+```sh
+python examples/tts/convert_csm_to_gguf.py
+
+# default output files:
+# sesame-csm-backbone.gguf
+# sesame-csm-decoder.gguf
+
+# optionally, quantize it
+# (lowest scheme is q8_0; it does not make sense to quantize further, quality degrades too much)
+python examples/tts/convert_csm_to_gguf.py --outtype q8_0
+```
+
+Compile the example:
+
+```sh
+cmake --build build -j --target llama-tts-csm
+```
+
+Run the example:
+
+```sh
+./build/bin/llama-tts-csm -m sesame-csm-backbone.gguf -p "[0]Hello world."
+# sesame-csm-decoder.gguf will automatically be loaded
+# make sure to place these 2 GGUF files in the same directory
+
+# output file: output.wav
+```
```
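
Note that only the backbone GGUF is passed with `-m`; as the tts-csm.cpp change further below shows, the example derives the decoder path by replacing "-backbone" with "-decoder" in the backbone path (via `string_replace_all` on `params_decoder.model`), which is why both files must sit in the same directory. A minimal standalone sketch of that naming convention (plain C++, no llama.cpp dependency; `derive_decoder_path` is a made-up helper name, not part of the example):

```cpp
#include <cstdio>
#include <string>

// Copy the backbone path and swap the "-backbone" part for "-decoder",
// mirroring how tts-csm.cpp builds the decoder model path.
static std::string derive_decoder_path(std::string path) {
    const std::string from = "-backbone";
    const std::string to   = "-decoder";
    const std::size_t pos = path.find(from);
    if (pos != std::string::npos) {
        path.replace(pos, from.size(), to);
    }
    return path;
}

int main() {
    // matches the defaults from the README above
    printf("%s\n", derive_decoder_path("sesame-csm-backbone.gguf").c_str()); // sesame-csm-decoder.gguf
    return 0;
}
```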

examples/tts/convert_csm_to_gguf.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -95,7 +95,7 @@ def __init__(self,
                  fname_out: Path,
                  ftype: gguf.LlamaFileType,
                  is_big_endian: bool,):
-
+
         if "<component>" not in fname_out.name:
             raise ValueError("Output file name must contain '<component>' placeholder, for example: 'sesame-csm-<component>.gguf'")
 
```

examples/tts/tts-csm.cpp

Lines changed: 88 additions & 46 deletions

```diff
@@ -6,6 +6,10 @@
 #include <vector>
 #include <fstream>
 #include <float.h>
+#include <cstring> // memcpy and strcmp
+#include <inttypes.h>
+
+// For more details on how this works, see: https://github.com/ggml-org/llama.cpp/pull/12648
 
 static void print_usage(int, char ** argv) {
     LOG("\nexample usage:\n");
@@ -30,6 +34,8 @@ static llama_token sample_greedy(const float * logits, int n_vocab) {
 static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
     std::vector<float> * embd = (std::vector<float> *) user_data;
 
+    // output_csm_proj is the embeddings output from backbone
+    // output_audio_embd is the embeddings output from decoder
     if (t && (strcmp(t->name, "output_csm_proj") == 0 || strcmp(t->name, "output_audio_embd") == 0)) {
         if (ask) return true;
 
@@ -45,13 +51,10 @@ static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
 int main(int argc, char ** argv) {
     common_params params;
 
-    params.model = "sesame-csm-backbone.gguf";
-    params.out_file = "output.wav";
-    params.prompt = "[0]Hello from Sesame.";
-
-    params.n_predict = 4096;
-    params.n_batch = 8192;
-    params.n_ctx = 8192;
+    params.model = "sesame-csm-backbone.gguf";
+    params.out_file = "output.wav";
+    params.prompt = "[0]Hello from Sesame.";
+    params.n_predict = 2048; // CSM's max trained seq length
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
         return 1;
@@ -66,6 +69,7 @@ int main(int argc, char ** argv) {
     params.warmup = false;
 
     common_params params_decoder(params); // duplicate the params
+    params_decoder.n_ctx = 64; // we never use more than this
     string_replace_all(params_decoder.model, "-backbone", "-decoder");
 
     common_init_result llama_backbone = common_init_from_params(params);
@@ -96,77 +100,114 @@ int main(int argc, char ** argv) {
     printf("\n");
 
     llama_pos n_past_bb = 0;
-    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
-    common_batch_clear(batch);
+    llama_batch batch_prompt = llama_batch_init(params.n_batch, 0, 1);
+    common_batch_clear(batch_prompt);
     for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-        common_batch_add(batch, prompt_tokens[i], n_past_bb++, { 0 }, false);
+        common_batch_add(batch_prompt, prompt_tokens[i], n_past_bb++, { 0 }, false);
     }
-    batch.logits[batch.n_tokens - 1] = true;
+    batch_prompt.logits[batch_prompt.n_tokens - 1] = true;
 
+    // inp_past_embd is the "squashed" embeddings from the decoder
     std::vector<float> inp_past_embd(2048, 0.0f);
     llama_batch batch_past_embd = llama_batch_init(1, inp_past_embd.size(), 1);
 
-    for (int k = 0; k < 32; ++k) {
-        if (llama_decode(ctx_bb, k == 0 ? batch : batch_past_embd) != 0) {
-            LOG_ERR("%s: llama_decode() failed\n", __func__);
+    int64_t t_gb_start = ggml_time_ms(); // global start time
+    int64_t t_bb = 0; // backbone time
+    int64_t n_bb_gen = 0; // backbone generation count
+    int64_t t_dc = 0; // decoder time
+    int64_t n_dc_gen = 0; // decoder generation count
+
+    bool is_stop = false;
+
+    // backbone generation loop
+    for (int k = 0; k < params.n_predict; ++k) {
+        bool is_prompt_processing = k == 0;
+
+        if (!is_prompt_processing) {
+            // generate the next RVQ semantic token
+            batch_past_embd.n_tokens = 1;
+            batch_past_embd.pos[0] = n_past_bb++;
+            batch_past_embd.seq_id[0][0] = 0;
+            batch_past_embd.n_seq_id[0] = 1;
+            batch_past_embd.logits[0] = true;
+            std::memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float));
+        }
+
+        int64_t t_bb_start = ggml_time_ms();
+        if (llama_decode(ctx_bb, is_prompt_processing ? batch_prompt : batch_past_embd) != 0) {
+            LOG_ERR("%s: backbone llama_decode() failed\n", __func__);
             return 1;
         }
+        n_bb_gen++;
+        t_bb += ggml_time_ms() - t_bb_start;
 
         auto vocab_dc = llama_model_get_vocab(model_dc);
-        auto logits = llama_get_logits_ith(ctx_bb, k == 0 ? (batch.n_tokens - 1) : 0);
+        auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? (batch_prompt.n_tokens - 1) : 0);
         // for (size_t i = 0; i < 10; ++i) {
         //     printf("%4.2f, ", logits[i]);
         // }
         // printf("\n");
 
-        llama_token latent_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
-        // printf("latent_token: %d\n", latent_token);
-        printf("%d,", latent_token);
+        llama_token semantic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
+        printf("%d,", semantic_tok);
 
         // for (size_t i = 0; i < 10; ++i) {
         //     printf("%4.2f, ", embd[i]);
         // }
         // printf("\n");
 
-
 
-        // decode
-        prompt_tokens.clear();
-        prompt_tokens.push_back(latent_token);
+        // decoder generation loop
         inp_past_embd = std::vector<float>(inp_past_embd.size(), 0.0f);
         {
             llama_kv_self_clear(ctx_dc);
             llama_batch batch_embd = llama_batch_init(1, embd.size(), 1);
             llama_batch batch_token = llama_batch_init(1, 0, 1);
+
+            // first "token" is the latent embeddings from backbone
             {
                 batch_embd.n_tokens = 1;
                 batch_embd.pos[0] = 0;
                 batch_embd.seq_id[0][0] = 0;
                 batch_embd.n_seq_id[0] = 1;
                 batch_embd.logits[0] = false;
-                memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
+                std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
+            }
+            if (llama_decode(ctx_dc, batch_embd) != 0) {
+                LOG_ERR("%s: decoder llama_decode(embd) failed\n", __func__);
+                return 1;
             }
-            llama_decode(ctx_dc, batch_embd);
-
-            llama_token audio_token = latent_token;
+
+            // then, decode the semantic_tok to generate acoustic tokens
+            llama_token tok = semantic_tok;
             int n_codes = 32;
-            int sum_codes = 0;
+            int sum_codes = 0; // to check if all codes are 0
            for (int i = 0; i < n_codes; ++i) {
                 common_batch_clear(batch_token);
                 // encoder vocab is further divided into 32 codebooks, each with 2051 entries
-                llama_token inp_tok = audio_token + 2051*i;
+                llama_token inp_tok = tok + 2051*i;
                 common_batch_add(batch_token, inp_tok, i+1, { 0 }, true);
-                llama_decode(ctx_dc, batch_token);
+
+                int64_t t_bb_start = ggml_time_ms();
+                if (llama_decode(ctx_dc, batch_token) != 0) {
+                    LOG_ERR("%s: decoder llama_decode(token) failed\n", __func__);
+                    return 1;
+                }
+                n_dc_gen++;
+                t_dc += ggml_time_ms() - t_bb_start;
+
+                // sample the acoustic token
                 auto logits = llama_get_logits_ith(ctx_dc, 0);
-                audio_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
+                llama_token acoustic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
 
-                // discard last code
+                // discard last code (only for embeddings)
                 if (i < n_codes - 1) {
-                    printf("%d,", audio_token);
-                    prompt_tokens.push_back(audio_token);
-                    sum_codes += audio_token;
+                    printf("%d,", acoustic_tok);
+                    tok = acoustic_tok; // next input token
+                    sum_codes += acoustic_tok;
                 }
 
+                // do progressive hsum of embeddings
                 GGML_ASSERT(inp_past_embd.size() == embd.size());
                 for (size_t i = 0; i < inp_past_embd.size(); ++i) {
                     inp_past_embd[i] += embd[i];
```
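
A note on the `llama_token inp_tok = tok + 2051*i;` line in the last hunk above: the decoder vocabulary is a flat concatenation of the 32 RVQ codebooks, each holding 2051 entries, so the code id for codebook `i` is offset by `2051*i`. A minimal standalone sketch of that index arithmetic (plain C++, no llama.cpp dependency; `flatten_code` is a made-up helper name used only for illustration):

```cpp
#include <cstdio>

// Constants taken from the comment in tts-csm.cpp: the decoder vocab is
// 32 codebooks of 2051 entries each, laid out back to back.
constexpr int N_CODEBOOKS   = 32;
constexpr int CODEBOOK_SIZE = 2051;

// Map (code, codebook index) to a token id in the flattened vocab,
// mirroring `llama_token inp_tok = tok + 2051*i;` in the example.
static int flatten_code(int code, int codebook) {
    return code + CODEBOOK_SIZE * codebook;
}

int main() {
    // e.g. code 123 taken from codebook 5 lands at 123 + 2051*5 = 10378
    printf("token id: %d\n", flatten_code(123, 5));
    printf("flattened vocab size: %d\n", N_CODEBOOKS * CODEBOOK_SIZE); // 65632
    return 0;
}
```

The remaining hunks of tts-csm.cpp: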

```diff
@@ -177,9 +218,8 @@ int main(int argc, char ** argv) {
             llama_batch_free(batch_embd);
             llama_batch_free(batch_token);
 
-            if (sum_codes == 0) {
-                return 0; // done
-            }
+            // if all codes are 0, then we are done
+            is_stop = sum_codes == 0;
         }
 
         // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb);
@@ -192,17 +232,19 @@ int main(int argc, char ** argv) {
         // }
         // printf("\n");
 
-        // prepare for the next iteration
-        {
-            batch_past_embd.n_tokens = 1;
-            batch_past_embd.pos[0] = n_past_bb;
-            batch_past_embd.seq_id[0][0] = 0;
-            batch_past_embd.n_seq_id[0] = 1;
-            batch_past_embd.logits[0] = true;
-            memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float));
+        if (is_stop) {
+            break;
         }
-        n_past_bb++;
     }
 
+    // print timing info
+    printf("\ntimings:\n");
+    printf(" backbone: %" PRId64 " ms, %" PRId64 " generated token (%.2f tok/s)\n", t_bb, n_bb_gen, (float)n_bb_gen*1000/(float)t_bb);
+    printf(" decoder:  %" PRId64 " ms, %" PRId64 " generated token (%.2f tok/s)\n", t_dc, n_dc_gen, (float)n_dc_gen*1000/(float)t_dc);
+    printf(" total:    %" PRId64 " ms\n\n", ggml_time_ms() - t_gb_start);
+
+    llama_batch_free(batch_prompt);
+    llama_batch_free(batch_past_embd);
+
     return 0;
 }
```
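
To summarize the control flow of the example after this change: each backbone step produces one semantic token plus a hidden embedding, the decoder then runs 32 steps (one per codebook), the decoder's per-step output embeddings are summed element-wise into `inp_past_embd` (the "squashed" feedback embedding for the next backbone step), and generation stops once all acoustic codes of a frame are 0. A minimal standalone sketch of that accumulation and stop check (plain C++, llama.cpp calls omitted; `decode_one_code` is a hypothetical stand-in for a decoder step):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for one decoder step: returns the sampled code for
// codebook `i` and fills `embd` with that step's output embedding.
// In the real example both come from llama_decode() plus the ggml callback.
static int decode_one_code(int i, std::vector<float> & embd) {
    std::fill(embd.begin(), embd.end(), 0.01f * (float)(i + 1)); // dummy embedding
    return i < 31 ? 7 : 0;                                       // dummy codes
}

int main() {
    const int n_codes = 32;
    std::vector<float> embd(2048);
    std::vector<float> inp_past_embd(2048, 0.0f); // fed back to the backbone next step

    int sum_codes = 0;
    for (int i = 0; i < n_codes; ++i) {
        const int code = decode_one_code(i, embd);
        if (i < n_codes - 1) {
            sum_codes += code; // the last code is only used for the embedding
        }
        // progressive element-wise sum of the decoder output embeddings
        for (size_t j = 0; j < inp_past_embd.size(); ++j) {
            inp_past_embd[j] += embd[j];
        }
    }

    const bool is_stop = sum_codes == 0; // an all-zero frame ends generation
    printf("sum_codes = %d, is_stop = %d\n", sum_codes, is_stop);
    return 0;
}
```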
