-#include "arg.h"
-#include "common.h"
-#include "log.h"
 #include "llama.h"
-
+#include <cstdio>
+#include <string>
 #include <vector>
 
 static void print_usage(int, char ** argv) {
-    LOG("\nexample usage:\n");
-    LOG("\n    %s -m model.gguf -p \"Hello my name is\" -n 32\n", argv[0]);
-    LOG("\n");
+    printf("\nexample usage:\n");
+    printf("\n    %s <model.gguf> [prompt]\n", argv[0]);
+    printf("\n");
 }
 
 int main(int argc, char ** argv) {
-    gpt_params params;
-
-    params.prompt = "Hello my name is";
-    params.n_predict = 32;
+    std::string model_path;
+    std::string prompt = "Hello my name is";
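+    // total length of the sequence including the prompt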
+    int n_predict = 32;
 
-    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
+    if (argc < 2) {
+        print_usage(argc, argv);
         return 1;
     }
+    model_path = argv[1];
 
-    gpt_init();
-
-    // total length of the sequence including the prompt
-    const int n_predict = params.n_predict;
-
-    // init LLM
-
-    llama_backend_init();
-    llama_numa_init(params.numa);
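+    // treat any remaining command-line arguments as the prompt, joined with spaces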
+    if (argc > 2) {
+        prompt = argv[2];
+        for (int i = 3; i < argc; i++) {
+            prompt += " ";
+            prompt += argv[i];
+        }
+    }
 
     // initialize the model
 
-    llama_model_params model_params = llama_model_params_from_gpt_params(params);
-
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers = 99; // offload all layers to GPU
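+    // values larger than the model's layer count simply offload everything; the setting has no effect in CPU-only builds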
+    llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
 
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
@@ -44,8 +41,9 @@ int main(int argc, char ** argv) {
 
     // initialize the context
 
-    llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
-
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.n_ctx = 512; // maximum context size
+    ctx_params.no_perf = false;
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
@@ -54,54 +52,58 @@ int main(int argc, char ** argv) {
     }
 
     auto sparams = llama_sampler_chain_default_params();
-
     sparams.no_perf = false;
-
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
     llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
 
     // tokenize the prompt
 
     std::vector<llama_token> tokens_list;
-    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
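+    // the first llama_tokenize call passes a NULL buffer to measure the prompt:
+    // on failure it returns the negative of the number of tokens that would have been written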
+    int n_tokens = llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+    tokens_list.resize(-n_tokens);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(), tokens_list.data(), tokens_list.size(), true, true) < 0) {
+        fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
+        return 1;
+    }
 
     const int n_ctx    = llama_n_ctx(ctx);
     const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size());
 
-    LOG("\n");
-    LOG_INF("%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req);
+    fprintf(stderr, "%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req);
+
 
     // make sure the KV cache is big enough to hold all the prompt and generated tokens
     if (n_kv_req > n_ctx) {
-        LOG_ERR("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
-        LOG_ERR("%s:        either reduce n_predict or increase n_ctx\n", __func__);
+        fprintf(stderr, "%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        fprintf(stderr, "%s:        either reduce n_predict or increase n_ctx\n", __func__);
         return 1;
     }
 
     // print the prompt token-by-token
 
-    LOG("\n");
+    fprintf(stderr, "\n");
 
     for (auto id : tokens_list) {
-        LOG("%s", llama_token_to_piece(ctx, id).c_str());
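+        // render the token id as text; a negative return value means the conversion failed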
+        char buf[128];
+        int n = llama_token_to_piece(model, id, buf, sizeof(buf), 0, true);
+        if (n < 0) {
+            fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
+            return 1;
+        }
+        std::string s(buf, n);
+        printf("%s", s.c_str());
     }
 
     // create a llama_batch with size 512
     // we use this object to submit token data for decoding
 
-    llama_batch batch = llama_batch_init(512, 0, 1);
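+    // llama_batch_get_one() wraps the prompt tokens in a single-sequence batch;
+    // with batch.logits left unset, llama_decode() outputs logits only for the last token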
+    llama_batch batch = llama_batch_get_one(tokens_list.data(), tokens_list.size(), 0, 0);
 
     // evaluate the initial prompt
-    for (size_t i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
-    }
-
-    // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
 
     if (llama_decode(ctx, batch) != 0) {
-        LOG("%s: llama_decode() failed\n", __func__);
+        fprintf(stderr, "%s: llama_decode() failed\n", __func__);
         return 1;
     }
 
@@ -114,24 +116,28 @@ int main(int argc, char ** argv) {
 
     while (n_cur <= n_predict) {
         // sample the next token
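+        // index -1 samples from the logits of the last token of the previous llama_decode() call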
+        llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);
         {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);
 
             // is it an end of generation?
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
-                LOG("\n");
+                fprintf(stderr, "\n");
 
                 break;
             }
 
-            LOG("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+            char buf[128];
+            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
+            if (n < 0) {
+                fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
+                return 1;
+            }
+            std::string s(buf, n);
+            printf("%s", s.c_str());
             fflush(stdout);
 
             // prepare the next batch
-            llama_batch_clear(batch);
-
-            // push this new token for next evaluation
-            llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
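+            // the next batch holds only the newly sampled token, placed at position n_cur in sequence 0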
+            batch = llama_batch_get_one(&new_token_id, 1, n_cur, 0);
 
             n_decode += 1;
         }
@@ -140,30 +146,26 @@ int main(int argc, char ** argv) {
 
         // evaluate the current batch with the transformer model
         if (llama_decode(ctx, batch)) {
-            LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
     }
 
-    LOG("\n");
+    fprintf(stderr, "\n");
 
     const auto t_main_end = ggml_time_us();
 
-    LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+    fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-    LOG("\n");
+    fprintf(stderr, "\n");
     llama_perf_sampler_print(smpl);
     llama_perf_context_print(ctx);
+    fprintf(stderr, "\n");
 
-    LOG("\n");
-
-    llama_batch_free(batch);
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
 
-    llama_backend_free();
-
     return 0;
 }