@@ -29,17 +29,30 @@ static void print_usage(int, char ** argv) {
     LOG("\n");
 }
 
-// greedy sampling with custom n_vocab
-static llama_token sample_greedy(const float * logits, int n_vocab) {
-    llama_token max_idx = -1;
-    float max_val = -FLT_MAX;
-    for (int i = 0; i < n_vocab; ++i) {
-        if (logits[i] > max_val) {
-            max_val = logits[i];
-            max_idx = i;
-        }
+// sampling with custom n_vocab
+// modified version of llama_sampler_sample()
+static llama_token sample_token(struct llama_sampler * smpl, const float * logits, int n_vocab) {
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
     }
-    return max_idx;
+
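+    // wrap the candidates in a llama_token_data_array, the format llama_sampler_apply() expects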
+    llama_token_data_array cur_p = {
+        /* .data       = */ cur.data(),
+        /* .size       = */ cur.size(),
+        /* .selected   = */ -1,
+        /* .sorted     = */ false,
+    };
+
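+    // run the sampler chain; the last sampler in the chain picks cur_p.selected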
+    llama_sampler_apply(smpl, &cur_p);
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+    auto token = cur_p.data[cur_p.selected].id;
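+    // notify the chain of the accepted token so stateful samplers can update (mirrors llama_sampler_sample())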
+    llama_sampler_accept(smpl, token);
+    return token;
 }
 
 // hook to retrieve the embeddings
@@ -63,11 +73,13 @@ static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
 int main(int argc, char ** argv) {
     common_params params;
 
-    params.model         = "sesame-csm-backbone.gguf";
-    params.vocoder.model = "kyutai-mimi.gguf";
-    params.out_file      = "output.wav";
-    params.prompt        = "";
-    params.n_predict     = 2048; // CSM's max trained seq length
+    params.model          = "sesame-csm-backbone.gguf";
+    params.vocoder.model  = "kyutai-mimi.gguf";
+    params.out_file       = "output.wav";
+    params.prompt         = "";
+    params.n_predict      = 2048; // CSM's max trained seq length
+    params.sampling.top_k = 50;  // default param from CSM python code
+    params.sampling.temp  = 0.9; // default param from CSM python code
 
     // HF model
     params.model_url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf";
@@ -115,11 +127,20 @@ int main(int argc, char ** argv) {
 
     mimi_model mimi(params.vocoder.model.c_str(), true);
 
+    // tokenize the prompt
     const llama_vocab * vocab = llama_model_get_vocab(model_bb);
     llama_tokens prompt_tokens = common_tokenize(vocab, params.prompt, false, true);
     prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab));
     prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab));
 
+    // init sampler
+    // the python implementation only has top-k and temperature sampling, so we'll use just that
+    llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_top_k(params.sampling.top_k));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(params.sampling.temp));
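+    // dist makes the final seeded random pick from the candidates that survive top-k and temperature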
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(params.sampling.seed));
+
     printf("prompt tokens:\n");
     for (size_t i = 0; i < prompt_tokens.size(); ++i) {
         printf("%d, ", prompt_tokens[i]);
@@ -176,7 +196,7 @@ int main(int argc, char ** argv) {
         // }
         // printf("\n");
 
-        llama_token semantic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
+        llama_token semantic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
         printf("Sem token %5d: %d,", 1+(int)generated_codes.size()/32, semantic_tok);
         generated_codes.push_back(semantic_tok);
 
@@ -227,7 +247,7 @@ int main(int argc, char ** argv) {
 
             // sample the acoustic token
             auto logits = llama_get_logits_ith(ctx_dc, 0);
-            llama_token acoustic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
+            llama_token acoustic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc));
 
             // discard last code (only for embeddings)
             if (i < n_codes - 1) {