@@ -125,7 +125,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
125125    ctx_params.n_threads        = n_threads;
126126    ctx_params.n_threads_batch  = n_threads;
127127
128-     llama_context * context = llama_new_context_with_model (model, ctx_params);
128+     llama_context * context = llama_init_from_model (model, ctx_params);
129129
130130    if  (!context) {
131131        LOGe (" llama_new_context_with_model() returned null)"  );
@@ -175,7 +175,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
175175
176176    const  auto  context = reinterpret_cast <llama_context *>(context_pointer);
177177    const  auto  model = reinterpret_cast <llama_model *>(model_pointer);
178-     const  auto  batch = reinterpret_cast <llama_batch  *>(batch_pointer);
178+     const  auto  batch = reinterpret_cast <llama_batch_ext  *>(batch_pointer);
179179
180180    const  int  n_ctx = llama_n_ctx (context);
181181
@@ -186,19 +186,20 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
186186    for  (nri = 0 ; nri < nr; nri++) {
187187        LOGi (" Benchmark prompt processing (pp)"  );
188188
189-         common_batch_clear (* batch);
189+         llama_batch_ext_clear ( batch);
190190
191191        const  int  n_tokens = pp;
192192        for  (i = 0 ; i < n_tokens; i++) {
193-             common_batch_add (*batch, 0 , i, { 0  }, false );
193+             llama_seq_id seq_id = 0 ;
194+             llama_batch_ext_add_text (batch, 0 , i, &seq_id, 1 , false );
194195        }
195196
196-         batch-> logits [batch-> n_tokens  -  1 ] =  true ;
197+         llama_batch_ext_set_output_last ( batch) ;
197198        llama_kv_self_clear (context);
198199
199200        const  auto  t_pp_start = ggml_time_us ();
200-         if  (llama_decode (context, * batch) != 0 ) {
201-             LOGi (" llama_decode () failed during prompt processing"  );
201+         if  (llama_decode_ext (context, batch) != 0 ) {
202+             LOGi (" llama_decode_ext () failed during prompt processing"  );
202203        }
203204        const  auto  t_pp_end = ggml_time_us ();
204205
@@ -210,14 +211,15 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
210211        const  auto  t_tg_start = ggml_time_us ();
211212        for  (i = 0 ; i < tg; i++) {
212213
213-             common_batch_clear (* batch);
214+             llama_batch_ext_clear ( batch);
214215            for  (j = 0 ; j < pl; j++) {
215-                 common_batch_add (*batch, 0 , i, { j }, true );
216+                 llama_seq_id seq_id = j;
217+                 llama_batch_ext_add_text (batch, 0 , i, &seq_id, 1 , true );
216218            }
217219
218-             LOGi (" llama_decode () text generation: %d"  , i);
219-             if  (llama_decode (context, * batch) != 0 ) {
220-                 LOGi (" llama_decode () failed during text generation"  );
220+             LOGi (" llama_decode_ext () text generation: %d"  , i);
221+             if  (llama_decode_ext (context, batch) != 0 ) {
222+                 LOGi (" llama_decode_ext () failed during text generation"  );
221223            }
222224        }
223225
@@ -272,42 +274,15 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
272274extern  " C" 
273275JNIEXPORT jlong JNICALL
274276Java_android_llama_cpp_LLamaAndroid_new_1batch (JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
275- 
276-     //  Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
277- 
278-     llama_batch *batch = new  llama_batch {
279-         0 ,
280-         nullptr ,
281-         nullptr ,
282-         nullptr ,
283-         nullptr ,
284-         nullptr ,
285-         nullptr ,
286-     };
287- 
288-     if  (embd) {
289-         batch->embd  = (float  *) malloc (sizeof (float ) * n_tokens * embd);
290-     } else  {
291-         batch->token  = (llama_token *) malloc (sizeof (llama_token) * n_tokens);
292-     }
293- 
294-     batch->pos       = (llama_pos *)     malloc (sizeof (llama_pos)      * n_tokens);
295-     batch->n_seq_id  = (int32_t  *)       malloc (sizeof (int32_t )        * n_tokens);
296-     batch->seq_id    = (llama_seq_id **) malloc (sizeof (llama_seq_id *) * n_tokens);
297-     for  (int  i = 0 ; i < n_tokens; ++i) {
298-         batch->seq_id [i] = (llama_seq_id *) malloc (sizeof (llama_seq_id) * n_seq_max);
299-     }
300-     batch->logits    = (int8_t  *)        malloc (sizeof (int8_t )         * n_tokens);
277+     llama_batch_ext * batch = llama_batch_ext_init (n_tokens, n_seq_max);
301278
302279    return  reinterpret_cast <jlong>(batch);
303280}
304281
305282extern  " C" 
306283JNIEXPORT void  JNICALL
307284Java_android_llama_cpp_LLamaAndroid_free_1batch (JNIEnv *, jobject, jlong batch_pointer) {
308-     // llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
309-     const  auto  batch = reinterpret_cast <llama_batch *>(batch_pointer);
310-     delete  batch;
285+     llama_batch_ext_free (reinterpret_cast <llama_batch_ext *>(batch_pointer));
311286}
312287
313288extern  " C" 
@@ -355,15 +330,15 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
355330
356331    const  auto  text = env->GetStringUTFChars (jtext, 0 );
357332    const  auto  context = reinterpret_cast <llama_context *>(context_pointer);
358-     const  auto  batch = reinterpret_cast <llama_batch  *>(batch_pointer);
333+     const  auto  batch = reinterpret_cast <llama_batch_ext  *>(batch_pointer);
359334
360335    bool  parse_special = (format_chat == JNI_TRUE);
361336    const  auto  tokens_list = common_tokenize (context, text, true , parse_special);
362337
363338    auto  n_ctx = llama_n_ctx (context);
364339    auto  n_kv_req = tokens_list.size () + n_len;
365340
366-     LOGi (" n_len = %d, n_ctx = %d, n_kv_req = %d"  , n_len, n_ctx, n_kv_req);
341+     LOGi (" n_len = %d, n_ctx = %d, n_kv_req = %d"  , ( int )  n_len, ( int )  n_ctx, ( int )  n_kv_req);
367342
368343    if  (n_kv_req > n_ctx) {
369344        LOGe (" error: n_kv_req > n_ctx, the required KV cache size is not big enough"  );
@@ -373,23 +348,24 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
373348        LOGi (" token: `%s`-> %d "  , common_token_to_piece (context, id).c_str (), id);
374349    }
375350
376-     common_batch_clear (* batch);
351+     llama_batch_ext_clear ( batch);
377352
378353    //  evaluate the initial prompt
379354    for  (auto  i = 0 ; i < tokens_list.size (); i++) {
380-         common_batch_add (*batch, tokens_list[i], i, { 0  }, false );
355+         llama_seq_id seq_id = 0 ;
356+         llama_batch_ext_add_text (batch, tokens_list[i], i, &seq_id, 1 , false );
381357    }
382358
383359    //  llama_decode will output logits only for the last token of the prompt
384-     batch-> logits [batch-> n_tokens  -  1 ] =  true ;
360+     llama_batch_ext_set_output_last ( batch) ;
385361
386-     if  (llama_decode (context, * batch) != 0 ) {
387-         LOGe (" llama_decode () failed"  );
362+     if  (llama_decode_ext (context, batch) != 0 ) {
363+         LOGe (" llama_decode_ext () failed"  );
388364    }
389365
390366    env->ReleaseStringUTFChars (jtext, text);
391367
392-     return  batch-> n_tokens ;
368+     return  llama_batch_ext_get_n_tokens ( batch) ;
393369}
394370
395371extern  " C" 
@@ -404,7 +380,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
404380        jobject intvar_ncur
405381) {
406382    const  auto  context = reinterpret_cast <llama_context *>(context_pointer);
407-     const  auto  batch   = reinterpret_cast <llama_batch    *>(batch_pointer);
383+     const  auto  batch   = reinterpret_cast <llama_batch_ext  *>(batch_pointer);
408384    const  auto  sampler = reinterpret_cast <llama_sampler *>(sampler_pointer);
409385    const  auto  model = llama_get_model (context);
410386    const  auto  vocab = llama_model_get_vocab (model);
@@ -433,13 +409,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
433409        new_token = env->NewStringUTF (" "  );
434410    }
435411
436-     common_batch_clear (*batch);
437-     common_batch_add (*batch, new_token_id, n_cur, { 0  }, true );
412+     llama_batch_ext_clear (batch);
413+     llama_seq_id seq_id = 0 ;
414+     llama_batch_ext_add_text (batch, new_token_id, n_cur, &seq_id, 1 , true );
438415
439416    env->CallVoidMethod (intvar_ncur, la_int_var_inc);
440417
441-     if  (llama_decode (context, * batch) != 0 ) {
442-         LOGe (" llama_decode () returned null"  );
418+     if  (llama_decode_ext (context, batch) != 0 ) {
419+         LOGe (" llama_decode_ext () returned null"  );
443420    }
444421
445422    return  new_token;
0 commit comments