Skip to content

Commit b0db7fc

Browse files
committed
android : adapt to new API
1 parent 23d7407 commit b0db7fc

File tree

3 files changed

+31
-94
lines changed

3 files changed

+31
-94
lines changed

common/common.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,35 +1576,6 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
15761576

15771577
#endif // LLAMA_USE_CURL
15781578

1579-
//
1580-
// Batch utils
1581-
//
1582-
1583-
// DEPRECATED
1584-
void common_batch_clear(struct llama_batch & batch) {
1585-
batch.n_tokens = 0;
1586-
}
1587-
1588-
// DEPRECATED
1589-
void common_batch_add(
1590-
struct llama_batch & batch,
1591-
llama_token id,
1592-
llama_pos pos,
1593-
const std::vector<llama_seq_id> & seq_ids,
1594-
bool logits) {
1595-
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
1596-
1597-
batch.token [batch.n_tokens] = id;
1598-
batch.pos [batch.n_tokens] = pos;
1599-
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
1600-
for (size_t i = 0; i < seq_ids.size(); ++i) {
1601-
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
1602-
}
1603-
batch.logits [batch.n_tokens] = logits;
1604-
1605-
batch.n_tokens++;
1606-
}
1607-
16081579
//
16091580
// Token utils
16101581
//

common/common.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -569,17 +569,6 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
569569
// Batch utils
570570
//
571571

572-
// DEPRECATED
573-
void common_batch_clear(struct llama_batch & batch);
574-
575-
// DEPRECATED
576-
void common_batch_add(
577-
struct llama_batch & batch,
578-
llama_token id,
579-
llama_pos pos,
580-
const std::vector<llama_seq_id> & seq_ids,
581-
bool logits);
582-
583572
// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
584573
// this is meant to be temporary
585574
struct common_batch {

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 31 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
125125
ctx_params.n_threads = n_threads;
126126
ctx_params.n_threads_batch = n_threads;
127127

128-
llama_context * context = llama_new_context_with_model(model, ctx_params);
128+
llama_context * context = llama_init_from_model(model, ctx_params);
129129

130130
if (!context) {
131131
LOGe("llama_new_context_with_model() returned null");
@@ -175,7 +175,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
175175

176176
const auto context = reinterpret_cast<llama_context *>(context_pointer);
177177
const auto model = reinterpret_cast<llama_model *>(model_pointer);
178-
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
178+
const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
179179

180180
const int n_ctx = llama_n_ctx(context);
181181

@@ -186,19 +186,20 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
186186
for (nri = 0; nri < nr; nri++) {
187187
LOGi("Benchmark prompt processing (pp)");
188188

189-
common_batch_clear(*batch);
189+
llama_batch_ext_clear(batch);
190190

191191
const int n_tokens = pp;
192192
for (i = 0; i < n_tokens; i++) {
193-
common_batch_add(*batch, 0, i, { 0 }, false);
193+
llama_seq_id seq_id = 0;
194+
llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
194195
}
195196

196-
batch->logits[batch->n_tokens - 1] = true;
197+
llama_batch_ext_set_output_last(batch);
197198
llama_kv_self_clear(context);
198199

199200
const auto t_pp_start = ggml_time_us();
200-
if (llama_decode(context, *batch) != 0) {
201-
LOGi("llama_decode() failed during prompt processing");
201+
if (llama_decode_ext(context, batch) != 0) {
202+
LOGi("llama_decode_ext() failed during prompt processing");
202203
}
203204
const auto t_pp_end = ggml_time_us();
204205

@@ -210,14 +211,15 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
210211
const auto t_tg_start = ggml_time_us();
211212
for (i = 0; i < tg; i++) {
212213

213-
common_batch_clear(*batch);
214+
llama_batch_ext_clear(batch);
214215
for (j = 0; j < pl; j++) {
215-
common_batch_add(*batch, 0, i, { j }, true);
216+
llama_seq_id seq_id = j;
217+
llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, true);
216218
}
217219

218-
LOGi("llama_decode() text generation: %d", i);
219-
if (llama_decode(context, *batch) != 0) {
220-
LOGi("llama_decode() failed during text generation");
220+
LOGi("llama_decode_ext() text generation: %d", i);
221+
if (llama_decode_ext(context, batch) != 0) {
222+
LOGi("llama_decode_ext() failed during text generation");
221223
}
222224
}
223225

@@ -272,42 +274,15 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
272274
extern "C"
273275
JNIEXPORT jlong JNICALL
274276
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
275-
276-
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
277-
278-
llama_batch *batch = new llama_batch {
279-
0,
280-
nullptr,
281-
nullptr,
282-
nullptr,
283-
nullptr,
284-
nullptr,
285-
nullptr,
286-
};
287-
288-
if (embd) {
289-
batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
290-
} else {
291-
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
292-
}
293-
294-
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
295-
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
296-
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
297-
for (int i = 0; i < n_tokens; ++i) {
298-
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
299-
}
300-
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
277+
llama_batch_ext * batch = llama_batch_ext_init(n_tokens, n_seq_max);
301278

302279
return reinterpret_cast<jlong>(batch);
303280
}
304281

305282
extern "C"
306283
JNIEXPORT void JNICALL
307284
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
308-
//llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
309-
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
310-
delete batch;
285+
llama_batch_ext_free(reinterpret_cast<llama_batch_ext *>(batch_pointer));
311286
}
312287

313288
extern "C"
@@ -355,15 +330,15 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
355330

356331
const auto text = env->GetStringUTFChars(jtext, 0);
357332
const auto context = reinterpret_cast<llama_context *>(context_pointer);
358-
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
333+
const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
359334

360335
bool parse_special = (format_chat == JNI_TRUE);
361336
const auto tokens_list = common_tokenize(context, text, true, parse_special);
362337

363338
auto n_ctx = llama_n_ctx(context);
364339
auto n_kv_req = tokens_list.size() + n_len;
365340

366-
LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
341+
LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", (int) n_len, (int) n_ctx, (int) n_kv_req);
367342

368343
if (n_kv_req > n_ctx) {
369344
LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
@@ -373,23 +348,24 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
373348
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
374349
}
375350

376-
common_batch_clear(*batch);
351+
llama_batch_ext_clear(batch);
377352

378353
// evaluate the initial prompt
379354
for (auto i = 0; i < tokens_list.size(); i++) {
380-
common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
355+
llama_seq_id seq_id = 0;
356+
llama_batch_ext_add_text(batch, tokens_list[i], i, &seq_id, 1, false);
381357
}
382358

383359
// llama_decode will output logits only for the last token of the prompt
384-
batch->logits[batch->n_tokens - 1] = true;
360+
llama_batch_ext_set_output_last(batch);
385361

386-
if (llama_decode(context, *batch) != 0) {
387-
LOGe("llama_decode() failed");
362+
if (llama_decode_ext(context, batch) != 0) {
363+
LOGe("llama_decode_ext() failed");
388364
}
389365

390366
env->ReleaseStringUTFChars(jtext, text);
391367

392-
return batch->n_tokens;
368+
return llama_batch_ext_get_n_tokens(batch);
393369
}
394370

395371
extern "C"
@@ -404,7 +380,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
404380
jobject intvar_ncur
405381
) {
406382
const auto context = reinterpret_cast<llama_context *>(context_pointer);
407-
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
383+
const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
408384
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
409385
const auto model = llama_get_model(context);
410386
const auto vocab = llama_model_get_vocab(model);
@@ -433,13 +409,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
433409
new_token = env->NewStringUTF("");
434410
}
435411

436-
common_batch_clear(*batch);
437-
common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
412+
llama_batch_ext_clear(batch);
413+
llama_seq_id seq_id = 0;
414+
llama_batch_ext_add_text(batch, new_token_id, n_cur, &seq_id, 1, true);
438415

439416
env->CallVoidMethod(intvar_ncur, la_int_var_inc);
440417

441-
if (llama_decode(context, *batch) != 0) {
442-
LOGe("llama_decode() returned null");
418+
if (llama_decode_ext(context, batch) != 0) {
419+
LOGe("llama_decode_ext() returned null");
443420
}
444421

445422
return new_token;

0 commit comments

Comments (0)