-#include "arg.h"
-#include "common.h"
-#include "log.h"
 #include "llama.h"
-
+#include <cstdio>
+#include <cstring>
+#include <string>
 #include <vector>
 
 static void print_usage(int, char ** argv) {
-    LOG("\nexample usage:\n");
-    LOG("\n    %s -m model.gguf -p \"Hello my name is\" -n 32\n", argv[0]);
-    LOG("\n");
+    printf("\nexample usage:\n");
+    printf("\n    %s -m model.gguf [-n n_predict] [-ngl n_gpu_layers] [prompt]\n", argv[0]);
+    printf("\n");
 }
 
 int main(int argc, char ** argv) {
-    gpt_params params;
-
-    params.prompt = "Hello my name is";
-    params.n_predict = 32;
-
-    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
-        return 1;
+    // path to the model gguf file
+    std::string model_path;
+    // prompt to generate text from
+    std::string prompt = "Hello my name is";
+    // number of layers to offload to the GPU
+    int ngl = 99;
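+    // (99 exceeds the layer count of typical models, so in practice this offloads every layer when a GPU backend is available)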
+    // number of tokens to predict
+    int n_predict = 32;
+
+    // parse command line arguments
+
+    {
+        int i = 1;
+        for (; i < argc; i++) {
+            if (strcmp(argv[i], "-m") == 0) {
+                if (i + 1 < argc) {
+                    model_path = argv[++i];
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else if (strcmp(argv[i], "-n") == 0) {
+                if (i + 1 < argc) {
+                    try {
+                        n_predict = std::stoi(argv[++i]);
+                    } catch (...) {
+                        print_usage(argc, argv);
+                        return 1;
+                    }
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else if (strcmp(argv[i], "-ngl") == 0) {
+                if (i + 1 < argc) {
+                    try {
+                        ngl = std::stoi(argv[++i]);
+                    } catch (...) {
+                        print_usage(argc, argv);
+                        return 1;
+                    }
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else {
+                // prompt starts here
+                break;
+            }
+        }
+        if (model_path.empty()) {
+            print_usage(argc, argv);
+            return 1;
+        }
+        if (i < argc) {
+            prompt = argv[i++];
+            for (; i < argc; i++) {
+                prompt += " ";
+                prompt += argv[i];
+            }
+        }
     }
 
-    gpt_init();
-
-    // total length of the sequence including the prompt
-    const int n_predict = params.n_predict;
-
-    // init LLM
-
-    llama_backend_init();
-    llama_numa_init(params.numa);
-
     // initialize the model
 
-    llama_model_params model_params = llama_model_params_from_gpt_params(params);
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers = ngl;
 
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+    llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
 
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;
     }
 
+    // tokenize the prompt
+
+    // find the number of tokens in the prompt
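+    // passing NULL for the output buffer makes llama_tokenize return the negative of the required token count, hence the leading minus below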
+    const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+
+    // allocate space for the tokens and tokenize the prompt
+    std::vector<llama_token> prompt_tokens(n_prompt);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
+        fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
+        return 1;
+    }
+
     // initialize the context
 
-    llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
+    llama_context_params ctx_params = llama_context_default_params();
+    // n_ctx is the context size
+    ctx_params.n_ctx = n_prompt + n_predict - 1;
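+    // (the final sampled token is never fed back into llama_decode, so the KV cache needs one slot fewer than the full sequence length)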
+    // n_batch is the maximum number of tokens that can be processed in a single call to llama_decode
+    ctx_params.n_batch = n_prompt;
+    // enable performance counters
+    ctx_params.no_perf = false;
 
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
@@ -53,117 +115,87 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    auto sparams = llama_sampler_chain_default_params();
+    // initialize the sampler
 
+    auto sparams = llama_sampler_chain_default_params();
     sparams.no_perf = false;
-
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
     llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
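+    // greedy sampling always picks the most probable token, so the output is deterministic for a given prompt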
 
-    // tokenize the prompt
-
-    std::vector<llama_token> tokens_list;
-    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
-
-    const int n_ctx = llama_n_ctx(ctx);
-    const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size());
-
-    LOG("\n");
-    LOG_INF("%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req);
-
-    // make sure the KV cache is big enough to hold all the prompt and generated tokens
-    if (n_kv_req > n_ctx) {
-        LOG_ERR("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
-        LOG_ERR("%s: either reduce n_predict or increase n_ctx\n", __func__);
-        return 1;
-    }
-
     // print the prompt token-by-token
 
-    LOG("\n");
-
-    for (auto id : tokens_list) {
-        LOG("%s", llama_token_to_piece(ctx, id).c_str());
-    }
-
-    // create a llama_batch with size 512
-    // we use this object to submit token data for decoding
-
-    llama_batch batch = llama_batch_init(512, 0, 1);
-
-    // evaluate the initial prompt
-    for (size_t i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+    for (auto id : prompt_tokens) {
+        char buf[128];
+        int n = llama_token_to_piece(model, id, buf, sizeof(buf), 0, true);
+        if (n < 0) {
+            fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
+            return 1;
+        }
+        std::string s(buf, n);
+        printf("%s", s.c_str());
     }
 
-    // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    // prepare a batch for the prompt
 
-    if (llama_decode(ctx, batch) != 0) {
-        LOG("%s: llama_decode() failed\n", __func__);
-        return 1;
-    }
+    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
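+    // llama_batch_get_one wraps the existing token array in a single-sequence batch starting at position 0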
 
     // main loop
 
-    int n_cur = batch.n_tokens;
+    const auto t_main_start = ggml_time_us();
     int n_decode = 0;
+    llama_token new_token_id;
 
-    const auto t_main_start = ggml_time_us();
+    for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
+
+        n_pos += batch.n_tokens;
 
-    while (n_cur <= n_predict) {
         // sample the next token
         {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);
+            new_token_id = llama_sampler_sample(smpl, ctx, -1);
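+            // index -1 selects the logits of the last token in the batch that was just decoded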
 
             // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
-                LOG("\n");
-
+            if (llama_token_is_eog(model, new_token_id)) {
                 break;
             }
 
-            LOG("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+            char buf[128];
+            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
+            if (n < 0) {
+                fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
+                return 1;
+            }
+            std::string s(buf, n);
+            printf("%s", s.c_str());
             fflush(stdout);
 
-            // prepare the next batch
-            llama_batch_clear(batch);
-
-            // push this new token for next evaluation
-            llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+            // prepare the next batch with the sampled token
+            batch = llama_batch_get_one(&new_token_id, 1, n_pos, 0);
 
             n_decode += 1;
         }
-
-        n_cur += 1;
-
-        // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch)) {
-            LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
     }
 
-    LOG("\n");
+    printf("\n");
 
     const auto t_main_end = ggml_time_us();
 
-    LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+    fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-    LOG("\n");
+    fprintf(stderr, "\n");
     llama_perf_sampler_print(smpl);
     llama_perf_context_print(ctx);
+    fprintf(stderr, "\n");
 
-    LOG("\n");
-
-    llama_batch_free(batch);
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
 
-    llama_backend_free();
-
     return 0;
 }