Commit 5be8e7d

add top-k and temp sampling
1 parent e31a75c commit 5be8e7d


examples/tts/tts-csm.cpp

Lines changed: 37 additions & 17 deletions
@@ -29,17 +29,27 @@ static void print_usage(int, char ** argv) {
     LOG("\n");
 }
 
-// greedy sampling with custom n_vocab
-static llama_token sample_greedy(const float * logits, int n_vocab) {
-    llama_token max_idx = -1;
-    float max_val = -FLT_MAX;
-    for (int i = 0; i < n_vocab; ++i) {
-        if (logits[i] > max_val) {
-            max_val = logits[i];
-            max_idx = i;
-        }
+// sampling with custom n_vocab
+// modified version of llama_sampler_sample()
+static llama_token sample_token(struct llama_sampler * smpl, const float * logits, int n_vocab) {
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
     }
-    return max_idx;
+
+    llama_token_data_array cur_p = {
+        /* .data     = */ cur.data(),
+        /* .size     = */ cur.size(),
+        /* .selected = */ -1,
+        /* .sorted   = */ false,
+    };
+
+    llama_sampler_apply(smpl, &cur_p);
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+    auto token = cur_p.data[cur_p.selected].id;
+    llama_sampler_accept(smpl, token);
+    return token;
 }
 
 // hook to retrieve the embeddings
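The new helper replaces the greedy argmax with whatever sampler chain is passed in: it wraps the raw logits in a llama_token_data_array, lets llama_sampler_apply() select an index, and feeds the choice back through llama_sampler_accept() so any stateful samplers stay in sync. A minimal usage sketch, assuming a decoder context ctx_dc and vocab vocab_dc as at the call sites further down in this diff (the seed constant is llama.h's LLAMA_DEFAULT_SEED):

    // build the same top-k -> temperature -> dist chain that main() sets up below
    llama_sampler_ptr smpl(llama_sampler_chain_init(llama_sampler_chain_default_params()));
    llama_sampler_chain_add(smpl.get(), llama_sampler_init_top_k(50));
    llama_sampler_chain_add(smpl.get(), llama_sampler_init_temp(0.9f));
    llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // sample one token from the decoder's logits, restricted to its vocab size
    const float * logits = llama_get_logits_ith(ctx_dc, 0);
    llama_token tok = sample_token(smpl.get(), logits, llama_vocab_n_tokens(vocab_dc));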
@@ -63,11 +73,13 @@ static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
 int main(int argc, char ** argv) {
     common_params params;
 
-    params.model         = "sesame-csm-backbone.gguf";
-    params.vocoder.model = "kyutai-mimi.gguf";
-    params.out_file      = "output.wav";
-    params.prompt        = "";
-    params.n_predict     = 2048; // CSM's max trained seq length
+    params.model          = "sesame-csm-backbone.gguf";
+    params.vocoder.model  = "kyutai-mimi.gguf";
+    params.out_file       = "output.wav";
+    params.prompt         = "";
+    params.n_predict      = 2048; // CSM's max trained seq length
+    params.sampling.top_k = 50;  // default param from CSM python code
+    params.sampling.temp  = 0.9; // default param from CSM python code
 
     // HF model
     params.model_url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf";
@@ -115,11 +127,19 @@ int main(int argc, char ** argv) {
 
     mimi_model mimi(params.vocoder.model.c_str(), true);
 
+    // tokenize the prompt
     const llama_vocab * vocab = llama_model_get_vocab(model_bb);
     llama_tokens prompt_tokens = common_tokenize(vocab, params.prompt, false, true);
     prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab));
     prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab));
 
+    // init sampler
+    // the python implementation only has top-k and temperature sampling, so we'll use just that
+    llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_top_k(params.sampling.top_k));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(params.sampling.temp));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(params.sampling.seed));
+
     printf("prompt tokens: \n");
     for (size_t i = 0; i < prompt_tokens.size(); ++i) {
         printf("%d, ", prompt_tokens[i]);
@@ -176,7 +196,7 @@ int main(int argc, char ** argv) {
         // }
         // printf("\n");
 
-        llama_token semantic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
+        llama_token semantic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
         printf("Sem token %5d : %d,", 1+(int)generated_codes.size()/32, semantic_tok);
         generated_codes.push_back(semantic_tok);
 
@@ -227,7 +247,7 @@ int main(int argc, char ** argv) {
227247

228248
// sample the acoustic token
229249
auto logits = llama_get_logits_ith(ctx_dc, 0);
230-
llama_token acoustic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
250+
llama_token acoustic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
231251

232252
// discard last code (only for embeddings)
233253
if (i < n_codes - 1) {
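Both call sites share the single sampler chain, so llama_sampler_accept() keeps its internal state (e.g. the dist stage's RNG) advancing across the semantic and acoustic draws. For orientation, the loop these two hunks live in looks roughly like the following structural sketch inferred from this diff; n_frames, ctx_bb, and the decode steps are placeholders, not the file's exact code:

    // one iteration per generated audio frame (hypothetical skeleton)
    for (int frame = 0; frame < n_frames; ++frame) {
        // semantic (0th codebook) token from the backbone model
        auto logits_bb = llama_get_logits_ith(ctx_bb, 0); // placeholder context name
        llama_token semantic_tok = sample_token(sampler.get(), logits_bb, llama_vocab_n_tokens(vocab_dc));
        generated_codes.push_back(semantic_tok);

        // acoustic codebook tokens from the decoder model
        for (int i = 0; i < n_codes; ++i) {
            // ... run one decoder step for codebook i ...
            auto logits_dc = llama_get_logits_ith(ctx_dc, 0);
            llama_token acoustic_tok = sample_token(sampler.get(), logits_dc, llama_vocab_n_tokens(vocab_dc));
            if (i < n_codes - 1) { // last code is used only for embeddings
                generated_codes.push_back(acoustic_tok);
            }
        }
    }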
