 #include "sampling.h"
 #include "log.h"
 #include "llama.h"
+#include "llama-cpp.h"
 
 #include <cmath>
 #include <cstdio>
@@ -174,7 +175,7 @@ int main(int argc, char ** argv) {
 
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 1));
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
@@ -192,11 +193,10 @@ int main(int argc, char ** argv) {
         LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
 
         for (int32_t i = 0; i < n_tokens_system; ++i) {
-            llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false);
+            batch.add_text(tokens_system[i], i, 0, false);
         }
 
-        if (llama_decode_ext(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch.get()) != 0) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -217,23 +217,23 @@ int main(int argc, char ** argv) {
             common_kv_cache_dump_view_seqs(kvc_view, 40);
         }
 
-        llama_batch_ext_clear(batch);
+        llama_batch_ext_clear(batch.get());
 
         // decode any currently ongoing sequences
         for (auto & client : clients) {
             if (client.seq_id == -1) {
                 continue;
             }
 
-            client.i_batch = llama_batch_ext_get_n_tokens(batch);
+            client.i_batch = llama_batch_ext_get_n_tokens(batch.get());
 
             llama_seq_id seq_id = client.id + 1;
-            llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true);
+            batch.add_text(client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, seq_id, true);
 
             client.n_decoded += 1;
         }
 
-        if (llama_batch_ext_get_n_tokens(batch) == 0) {
+        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
                 llama_kv_self_seq_rm(ctx, i, -1, -1);
@@ -245,7 +245,7 @@ int main(int argc, char ** argv) {
         }
 
         // insert new sequences for decoding
-        if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) {
+        if (cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
             for (auto & client : clients) {
                 if (client.seq_id == -1 && g_seq_id < n_seq) {
                     client.seq_id = g_seq_id;
@@ -265,17 +265,17 @@ int main(int argc, char ** argv) {
 
                     for (size_t i = 0; i < tokens_prompt.size(); ++i) {
                         llama_seq_id seq_id = client.id + 1;
-                        llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false);
+                        batch.add_text(tokens_prompt[i], i + n_tokens_system, seq_id, false);
                     }
 
                     // extract the logits only for the last token
-                    if (llama_batch_ext_get_n_tokens(batch) > 0) {
-                        llama_batch_ext_set_output_last(batch);
+                    if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
+                        llama_batch_ext_set_output_last(batch.get());
                     }
 
                     client.n_prompt  = tokens_prompt.size();
                     client.n_decoded = 0;
-                    client.i_batch   = llama_batch_ext_get_n_tokens(batch) - 1;
+                    client.i_batch   = llama_batch_ext_get_n_tokens(batch.get()) - 1;
 
                     LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
 
@@ -289,14 +289,14 @@ int main(int argc, char ** argv) {
             }
         }
 
-        if (llama_batch_ext_get_n_tokens(batch) == 0) {
+        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
             break;
         }
 
         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;
 
-        int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch);
+        int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch.get());
         for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) {
             // experiment: process in powers of 2
             // if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
@@ -307,7 +307,7 @@ int main(int argc, char ** argv) {
 
             const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i));
 
-            llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens);
+            llama_batch_ext * batch_view = llama_batch_ext_get_view(batch.get(), i, n_tokens);
             const int ret = llama_decode_ext(ctx, batch_view);
             llama_batch_ext_free(batch_view);
             if (ret != 0) {
@@ -413,8 +413,6 @@ int main(int argc, char ** argv) {
     // TODO: print sampling/grammar timings for all clients
     llama_perf_context_print(ctx);
 
-    llama_batch_ext_free(batch);
-
     llama_backend_free();
 
     LOG("\n\n");
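
For reference, the diff leans on the llama_batch_ext_ptr helper pulled in via the new "llama-cpp.h" include. Below is a minimal sketch of what such a wrapper could look like, assuming it follows the std::unique_ptr-with-custom-deleter pattern used by the other *_ptr helpers in llama-cpp.h and adds an add_text() convenience member; the names and signatures here are inferred from the call sites above, not copied from the PR.

// Sketch only: assumed shape of llama_batch_ext_ptr, not the PR's verbatim definition.
#include <memory>

#include "llama.h"

// frees the batch when the owning pointer goes out of scope
struct llama_batch_ext_deleter {
    void operator()(llama_batch_ext * batch) const { llama_batch_ext_free(batch); }
};

struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> {
    using std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>::unique_ptr;

    // convenience wrapper for the common single-sequence case,
    // forwarding to llama_batch_ext_add_text() as used in the hunks above
    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool output) {
        llama_batch_ext_add_text(this->get(), token, pos, &seq_id, 1, output);
    }
};

Because the batch is now owned by the smart pointer, it is released automatically when main() returns, which is why the last hunk drops the explicit llama_batch_ext_free(batch) call.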