@@ -70,6 +70,7 @@ struct mtmd_cli_context {
7070 llama_model * model;
7171 llama_context * lctx;
7272 const llama_vocab * vocab;
73+ common_sampler * smpl;
7374 llama_batch batch;
7475 int n_batch;
7576
@@ -89,8 +90,9 @@ struct mtmd_cli_context {
8990 model = llama_init.model.get();
9091 lctx = llama_init.context.get();
9192 vocab = llama_model_get_vocab(model);
93+ smpl = common_sampler_init(model, params.sampling);
9294 n_threads = params.cpuparams.n_threads;
93- batch = llama_batch_init(params.n_batch, 0, 1);
95+ batch = llama_batch_init(1, 0, 1); // batch for next token generation
9496 n_batch = params.n_batch;
9597
9698 if (!model || !lctx) {
@@ -118,6 +120,11 @@ struct mtmd_cli_context {
118120 }
119121 }
120122
123+ ~mtmd_cli_context() {
124+ llama_batch_free(batch);
125+ common_sampler_free(smpl);
126+ }
127+
121128 void init_vision_context(common_params & params) {
122129 const char * clip_path = params.mmproj.path.c_str();
123130 mtmd_context_params mparams = mtmd_context_params_default();
@@ -153,17 +160,17 @@ struct mtmd_cli_context {
153160 }
154161};
155162
156- static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
163+ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
157164 llama_tokens generated_tokens;
158165 for (int i = 0; i < n_predict; i++) {
159166 if (i > n_predict || !g_is_generating || g_is_interrupted) {
160167 LOG("\n");
161168 break;
162169 }
163170
164- llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
171+ llama_token token_id = common_sampler_sample(ctx.smpl, ctx.lctx, -1);
165172 generated_tokens.push_back(token_id);
166- common_sampler_accept(smpl, token_id, true);
173+ common_sampler_accept(ctx.smpl, token_id, true);
167174
168175 if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
169176 LOG("\n");
@@ -261,7 +268,6 @@ int main(int argc, char ** argv) {
261268
262269 bool is_single_turn = !params.prompt .empty () && !params.image .empty ();
263270
264- struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
265271 int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
266272
267273 // Ctrl+C handling
@@ -300,7 +306,7 @@ int main(int argc, char ** argv) {
300306 if (eval_message(ctx, msg, true)) {
301307 return 1;
302308 }
303- if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
309+ if (!g_is_interrupted && generate_response(ctx, n_predict)) {
304310 return 1;
305311 }
306312
@@ -366,7 +372,7 @@ int main(int argc, char ** argv) {
366372 return 1;
367373 }
368374 if (g_is_interrupted) break;
369- if (generate_response(ctx, smpl, n_predict)) {
375+ if (generate_response(ctx, n_predict)) {
370376 return 1;
371377 }
372378 content.clear ();
0 commit comments