-#include "arg.h"
-#include "common.h"
-#include "log.h"
 #include "llama.h"
-
+#include <cstdio>
+#include <cstring>
+#include <string>
 #include <vector>
 
 static void print_usage(int, char ** argv) {
-    LOG("\nexample usage:\n");
-    LOG("\n    %s -m model.gguf -p \"Hello my name is\" -n 32\n", argv[0]);
-    LOG("\n");
+    printf("\nexample usage:\n");
+    printf("\n    %s -m model.gguf [-n n_predict] [-ngl n_gpu_layers] [prompt]\n", argv[0]);
+    printf("\n");
 }
 
 int main(int argc, char ** argv) {
-    gpt_params params;
-
-    params.prompt = "Hello my name is";
-    params.n_predict = 32;
-
-    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
-        return 1;
+    // path to the model gguf file
+    std::string model_path;
+    // prompt to generate text from
+    std::string prompt = "Hello my name is";
+    // number of layers to offload to the GPU
+    int ngl = 99;
+    // number of tokens to predict
+    int n_predict = 32;
+
+    // parse command line arguments
+
+    {
+        int i = 1;
+        for (; i < argc; i++) {
+            if (strcmp(argv[i], "-m") == 0) {
+                if (i + 1 < argc) {
+                    model_path = argv[++i];
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else if (strcmp(argv[i], "-n") == 0) {
+                if (i + 1 < argc) {
+                    try {
+                        n_predict = std::stoi(argv[++i]);
+                    } catch (...) {
+                        print_usage(argc, argv);
+                        return 1;
+                    }
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else if (strcmp(argv[i], "-ngl") == 0) {
+                if (i + 1 < argc) {
+                    try {
+                        ngl = std::stoi(argv[++i]);
+                    } catch (...) {
+                        print_usage(argc, argv);
+                        return 1;
+                    }
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else {
+                // prompt starts here
+                break;
+            }
+        }
+        if (model_path.empty()) {
+            print_usage(argc, argv);
+            return 1;
+        }
+        if (i < argc) {
+            prompt = argv[i++];
+            for (; i < argc; i++) {
+                prompt += " ";
+                prompt += argv[i];
+            }
+        }
     }
 
-    gpt_init();
-
-    // total length of the sequence including the prompt
-    const int n_predict = params.n_predict;
-
-    // init LLM
-
-    llama_backend_init();
-    llama_numa_init(params.numa);
-
     // initialize the model
 
-    llama_model_params model_params = llama_model_params_from_gpt_params(params);
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers = ngl;
 
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+    llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
 
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;
     }
 
+    // tokenize the prompt
+
+    // find the number of tokens in the prompt
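+    // (llama_tokenize returns the negative of the required token count when the output buffer is NULL, hence the leading minus)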
+    const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+
+    // allocate space for the tokens and tokenize the prompt
+    std::vector<llama_token> prompt_tokens(n_prompt);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
+        fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
+        return 1;
+    }
+
     // initialize the context
 
-    llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
+    llama_context_params ctx_params = llama_context_default_params();
+    // n_ctx is the context size
+    ctx_params.n_ctx = n_prompt + n_predict - 1;
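+    // (the final sampled token is never fed back through llama_decode, so n_prompt + n_predict - 1 positions suffice)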
+    // n_batch is the maximum number of tokens that can be processed in a single call to llama_decode
+    ctx_params.n_batch = n_prompt;
+    // enable performance counters
+    ctx_params.no_perf = false;
 
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
@@ -53,117 +115,87 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    auto sparams = llama_sampler_chain_default_params();
+    // initialize the sampler
 
+    auto sparams = llama_sampler_chain_default_params();
     sparams.no_perf = false;
-
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
     llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
 
-    // tokenize the prompt
-
-    std::vector<llama_token> tokens_list;
-    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
-
-    const int n_ctx    = llama_n_ctx(ctx);
-    const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size());
-
-    LOG("\n");
-    LOG_INF("%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req);
-
-    // make sure the KV cache is big enough to hold all the prompt and generated tokens
-    if (n_kv_req > n_ctx) {
-        LOG_ERR("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
-        LOG_ERR("%s:        either reduce n_predict or increase n_ctx\n", __func__);
-        return 1;
-    }
-
     // print the prompt token-by-token
 
-    LOG("\n");
-
-    for (auto id : tokens_list) {
-        LOG("%s", llama_token_to_piece(ctx, id).c_str());
-    }
-
-    // create a llama_batch with size 512
-    // we use this object to submit token data for decoding
-
-    llama_batch batch = llama_batch_init(512, 0, 1);
-
-    // evaluate the initial prompt
-    for (size_t i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+    for (auto id : prompt_tokens) {
+        char buf[128];
+        int n = llama_token_to_piece(model, id, buf, sizeof(buf), 0, true);
+        if (n < 0) {
+            fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
+            return 1;
+        }
+        std::string s(buf, n);
+        printf("%s", s.c_str());
     }
 
-    // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    // prepare a batch for the prompt
 
-    if (llama_decode(ctx, batch) != 0) {
-        LOG("%s: llama_decode() failed\n", __func__);
-        return 1;
-    }
+    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
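+    // (llama_batch_get_one wraps the token array in a single-sequence batch starting at position 0)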
 
     // main loop
 
-    int n_cur    = batch.n_tokens;
+    const auto t_main_start = ggml_time_us();
     int n_decode = 0;
+    llama_token new_token_id;
 
-    const auto t_main_start = ggml_time_us();
+    for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
+
+        n_pos += batch.n_tokens;
 
-    while (n_cur <= n_predict) {
         // sample the next token
         {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);
+            new_token_id = llama_sampler_sample(smpl, ctx, -1);
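+            // (index -1 samples from the logits of the last token in the batch that was just decoded)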
 
             // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
-                LOG("\n");
-
+            if (llama_token_is_eog(model, new_token_id)) {
                 break;
             }
 
-            LOG("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+            char buf[128];
+            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
+            if (n < 0) {
+                fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
+                return 1;
+            }
+            std::string s(buf, n);
+            printf("%s", s.c_str());
             fflush(stdout);
 
-            // prepare the next batch
-            llama_batch_clear(batch);
-
-            // push this new token for next evaluation
-            llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+            // prepare the next batch with the sampled token
+            batch = llama_batch_get_one(&new_token_id, 1, n_pos, 0);
 
             n_decode += 1;
         }
-
-        n_cur += 1;
-
-        // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch)) {
-            LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
     }
 
-    LOG("\n");
+    printf("\n");
 
     const auto t_main_end = ggml_time_us();
 
-    LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+    fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-    LOG("\n");
+    fprintf(stderr, "\n");
     llama_perf_sampler_print(smpl);
     llama_perf_context_print(ctx);
+    fprintf(stderr, "\n");
 
-    LOG("\n");
-
-    llama_batch_free(batch);
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
 
-    llama_backend_free();
-
     return 0;
 }