@@ -217,23 +217,23 @@ int main(int argc, char ** argv) {
             common_kv_cache_dump_view_seqs(kvc_view, 40);
         }

-        llama_batch_ext_clear(batch.get());
+        batch.clear();

         // decode any currently ongoing sequences
         for (auto & client : clients) {
             if (client.seq_id == -1) {
                 continue;
             }

-            client.i_batch = llama_batch_ext_get_n_tokens(batch.get());
+            client.i_batch = batch.n_tokens();

             llama_seq_id seq_id = client.id + 1;
             batch.add_text(client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, seq_id, true);

             client.n_decoded += 1;
         }

-        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+        if (batch.n_tokens() == 0) {
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
                 llama_kv_self_seq_rm(ctx, i, -1, -1);
@@ -245,7 +245,7 @@ int main(int argc, char ** argv) {
         }

         // insert new sequences for decoding
-        if (cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+        if (cont_batching || batch.n_tokens() == 0) {
             for (auto & client : clients) {
                 if (client.seq_id == -1 && g_seq_id < n_seq) {
                     client.seq_id = g_seq_id;
@@ -269,13 +269,13 @@ int main(int argc, char ** argv) {
                     }

                     // extract the logits only for the last token
-                    if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
+                    if (batch.n_tokens() > 0) {
                         llama_batch_ext_set_output_last(batch.get());
                     }

                     client.n_prompt  = tokens_prompt.size();
                     client.n_decoded = 0;
-                    client.i_batch   = llama_batch_ext_get_n_tokens(batch.get()) - 1;
+                    client.i_batch   = batch.n_tokens() - 1;

                     LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);

@@ -289,14 +289,14 @@ int main(int argc, char ** argv) {
             }
         }

-        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+        if (batch.n_tokens() == 0) {
             break;
         }

         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;

-        int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch.get());
+        int32_t n_tokens_in_batch = batch.n_tokens();
         for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) {
             // experiment: process in powers of 2
             // if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
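
Taken together, these hunks swap direct llama_batch_ext_* C calls for methods on a batch object. The wrapper type itself is not shown in this diff; below is a minimal sketch of what such a wrapper could look like, inferred only from the call sites above. The struct name batch_ext_wrapper and the exact signature of llama_batch_ext_add_text are assumptions, not the PR's actual code:

#include "llama.h"  // assumes the llama_batch_ext API from this branch

// Hypothetical sketch -- NOT the PR's actual wrapper. Method names mirror
// the call sites in the diff (clear, n_tokens, add_text, get).
struct batch_ext_wrapper {
    llama_batch_ext * ptr;  // assumed non-owning; lifetime managed elsewhere

    // drop all queued tokens so the batch can be refilled each iteration
    void clear() { llama_batch_ext_clear(ptr); }

    // number of tokens currently queued
    int32_t n_tokens() const { return llama_batch_ext_get_n_tokens(ptr); }

    // queue one token at position `pos` for a single sequence;
    // `output` marks it as a token whose logits should be returned
    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool output) {
        llama_batch_ext_add_text(ptr, token, pos, &seq_id, 1, output);  // assumed C signature
    }

    // escape hatch for C calls without a wrapper method yet
    llama_batch_ext * get() { return ptr; }
};

Note that the diff keeps batch.get() where no wrapper method exists yet (llama_batch_ext_set_output_last), consistent with an incremental migration of call sites.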