@@ -70,6 +70,7 @@ struct mtmd_cli_context {
     llama_model       * model;
     llama_context     * lctx;
     const llama_vocab * vocab;
+    common_sampler    * smpl;
     llama_batch         batch;
     int                 n_batch;
 
@@ -89,8 +90,9 @@ struct mtmd_cli_context {
         model = llama_init.model.get();
         lctx = llama_init.context.get();
         vocab = llama_model_get_vocab(model);
+        smpl = common_sampler_init(model, params.sampling);
         n_threads = params.cpuparams.n_threads;
-        batch = llama_batch_init(params.n_batch, 0, 1);
+        batch = llama_batch_init(1, 0, 1); // batch for next token generation
         n_batch = params.n_batch;
 
         if (!model || !lctx) {
@@ -118,6 +120,11 @@ struct mtmd_cli_context {
         }
     }
 
+    ~mtmd_cli_context() {
+        llama_batch_free(batch);
+        common_sampler_free(smpl);
+    }
+
     void init_vision_context(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
         mtmd_context_params mparams = mtmd_context_params_default();
@@ -153,17 +160,17 @@ struct mtmd_cli_context {
     }
 };
 
-static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
+static int generate_response(mtmd_cli_context & ctx, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating || g_is_interrupted) {
             LOG("\n");
             break;
         }
 
-        llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
+        llama_token token_id = common_sampler_sample(ctx.smpl, ctx.lctx, -1);
         generated_tokens.push_back(token_id);
-        common_sampler_accept(smpl, token_id, true);
+        common_sampler_accept(ctx.smpl, token_id, true);
 
         if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
             LOG("\n");
@@ -261,7 +268,6 @@ int main(int argc, char ** argv) {
 
     bool is_single_turn = !params.prompt.empty() && !params.image.empty();
 
-    struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
     int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
 
     // Ctrl+C handling
@@ -300,7 +306,7 @@ int main(int argc, char ** argv) {
         if (eval_message(ctx, msg, true)) {
             return 1;
         }
-        if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
+        if (!g_is_interrupted && generate_response(ctx, n_predict)) {
             return 1;
         }
 
@@ -366,7 +372,7 @@ int main(int argc, char ** argv) {
                 return 1;
             }
             if (g_is_interrupted) break;
-            if (generate_response(ctx, smpl, n_predict)) {
+            if (generate_response(ctx, n_predict)) {
                 return 1;
             }
             content.clear();
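
The diff above moves sampler ownership from main() into mtmd_cli_context and adds a destructor, so the batch and sampler are released when the context goes out of scope rather than being freed (or leaked) by the caller. A minimal sketch of that RAII ownership pattern, using stand-in types and create/free functions rather than the real llama.cpp API:

```cpp
// Stand-ins for common_sampler / common_sampler_init / common_sampler_free;
// only the ownership pattern mirrors the change above.
#include <cstdio>

struct sampler { int state; };
static sampler * sampler_create()          { return new sampler{0}; }
static void      sampler_free(sampler * s) { delete s; }

struct cli_context {
    sampler * smpl;

    cli_context() : smpl(sampler_create()) {}   // acquire in the constructor
    ~cli_context() { sampler_free(smpl); }      // release in the destructor

    cli_context(const cli_context &) = delete;  // single owner: non-copyable
    cli_context & operator=(const cli_context &) = delete;
};

int main() {
    cli_context ctx;                            // freed automatically at scope exit
    std::printf("sampler state: %d\n", ctx.smpl->state);
    return 0;
}
```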