// #include <android/asset_manager_jni.h>
#include <android/log.h>
#include <cstdlib>
+#include <ctime>
#include <sys/sysinfo.h>
#include <string>
#include <thread>
@@ -21,6 +22,13 @@ static inline int min(int a, int b) {
    return (a < b) ? a : b;
}

+static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
+    if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
+    else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
+    else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
+    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
+}
+
extern "C" {

// Method to create WritableMap
@@ -139,14 +147,20 @@ Java_com_rnllama_LlamaContext_initContext(
    jint n_gpu_layers, // TODO: Support this
    jboolean use_mlock,
    jboolean use_mmap,
+    jboolean vocab_only,
    jstring lora_str,
    jfloat lora_scaled,
    jfloat rope_freq_base,
    jfloat rope_freq_scale
) {
    UNUSED(thiz);

-    gpt_params defaultParams;
+    common_params defaultParams;
+
+    defaultParams.vocab_only = vocab_only;
+    if (vocab_only) {
+        defaultParams.warmup = false;
+    }

    const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
    defaultParams.model = model_path_chars;
@@ -159,10 +173,10 @@ Java_com_rnllama_LlamaContext_initContext(
    int max_threads = std::thread::hardware_concurrency();
    // Use 2 threads by default on 4-core devices, 4 threads on more cores
    int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
-    defaultParams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
+    defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

    defaultParams.n_gpu_layers = n_gpu_layers;
-
+
    defaultParams.use_mlock = use_mlock;
    defaultParams.use_mmap = use_mmap;

@@ -235,7 +249,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
    UNUSED(thiz);
    auto llama = context_map[(long) context_ptr];

-    std::vector<llama_chat_msg> chat;
+    std::vector<common_chat_msg> chat;

    int messages_len = env->GetArrayLength(messages);
    for (int i = 0; i < messages_len; i++) {
@@ -259,7 +273,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
    }

    const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
-    std::string formatted_chat = llama_chat_apply_template(llama->model, tmpl_chars, chat, true);
+    std::string formatted_chat = common_chat_apply_template(llama->model, tmpl_chars, chat, true);

    return env->NewStringUTF(formatted_chat.c_str());
}
@@ -364,6 +378,8 @@ Java_com_rnllama_LlamaContext_doCompletion(
    jint top_k,
    jfloat top_p,
    jfloat min_p,
+    jfloat xtc_threshold,
+    jfloat xtc_probability,
    jfloat tfs_z,
    jfloat typical_p,
    jint seed,
@@ -377,18 +393,18 @@ Java_com_rnllama_LlamaContext_doCompletion(

    llama->rewind();

-    llama_reset_timings(llama->ctx);
+    // llama_reset_timings(llama->ctx);

    llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
-    llama->params.seed = seed;
+    llama->params.sparams.seed = (seed == -1) ? time(NULL) : seed;

    int max_threads = std::thread::hardware_concurrency();
    // Use 2 threads by default on 4-core devices, 4 threads on more cores
    int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
-    llama->params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
+    llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

    llama->params.n_predict = n_predict;
-    llama->params.ignore_eos = ignore_eos;
+    llama->params.sparams.ignore_eos = ignore_eos;

    auto & sparams = llama->params.sparams;
    sparams.temp = temperature;
@@ -404,13 +420,15 @@ Java_com_rnllama_LlamaContext_doCompletion(
    sparams.top_p = top_p;
    sparams.min_p = min_p;
    sparams.tfs_z = tfs_z;
-    sparams.typical_p = typical_p;
+    sparams.typ_p = typical_p;
    sparams.n_probs = n_probs;
    sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
+    sparams.xtc_threshold = xtc_threshold;
+    sparams.xtc_probability = xtc_probability;

    sparams.logit_bias.clear();
    if (ignore_eos) {
-        sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
+        sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
    }

    const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
@@ -424,9 +442,9 @@ Java_com_rnllama_LlamaContext_doCompletion(
            llama_token tok = static_cast<llama_token>(doubleArray[0]);
            if (tok >= 0 && tok < n_vocab) {
                if (doubleArray[1] != 0) {  // If the second element is not false (0)
-                    sparams.logit_bias[tok] = doubleArray[1];
+                    sparams.logit_bias[tok].bias = doubleArray[1];
                } else {
-                    sparams.logit_bias[tok] = -INFINITY;
+                    sparams.logit_bias[tok].bias = -INFINITY;
                }
            }

@@ -460,7 +478,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
        if (token_with_probs.tok == -1 || llama->incomplete) {
            continue;
        }
-        const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);
+        const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);

        size_t pos = std::min(sent_count, llama->generated_text.size());

@@ -495,7 +513,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
        putString(env, tokenResult, "token", to_send.c_str());

        if (llama->params.sparams.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
+            const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
            size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
            size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
            if (probs_pos < probs_stop_pos) {
@@ -512,7 +530,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
        }
    }

-    llama_print_timings(llama->ctx);
+    llama_perf_context_print(llama->ctx);
    llama->is_predicting = false;

    auto result = createWriteableMap(env);
@@ -527,16 +545,17 @@ Java_com_rnllama_LlamaContext_doCompletion(
    putString(env, result, "stopping_word", llama->stopping_word.c_str());
    putInt(env, result, "tokens_cached", llama->n_past);

-    const auto timings = llama_get_timings(llama->ctx);
+    const auto timings_token = llama_perf_context(llama->ctx);
+
    auto timingsResult = createWriteableMap(env);
-    putInt(env, timingsResult, "prompt_n", timings.n_p_eval);
-    putInt(env, timingsResult, "prompt_ms", timings.t_p_eval_ms);
-    putInt(env, timingsResult, "prompt_per_token_ms", timings.t_p_eval_ms / timings.n_p_eval);
-    putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
-    putInt(env, timingsResult, "predicted_n", timings.n_eval);
-    putInt(env, timingsResult, "predicted_ms", timings.t_eval_ms);
-    putInt(env, timingsResult, "predicted_per_token_ms", timings.t_eval_ms / timings.n_eval);
-    putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings.t_eval_ms * timings.n_eval);
+    putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval);
+    putInt(env, timingsResult, "prompt_ms", timings_token.t_p_eval_ms);
+    putInt(env, timingsResult, "prompt_per_token_ms", timings_token.t_p_eval_ms / timings_token.n_p_eval);
+    putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings_token.t_p_eval_ms * timings_token.n_p_eval);
+    putInt(env, timingsResult, "predicted_n", timings_token.n_eval);
+    putInt(env, timingsResult, "predicted_ms", timings_token.t_eval_ms);
+    putInt(env, timingsResult, "predicted_per_token_ms", timings_token.t_eval_ms / timings_token.n_eval);
+    putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings_token.t_eval_ms * timings_token.n_eval);

    putMap(env, result, "timings", timingsResult);

@@ -569,7 +588,7 @@ Java_com_rnllama_LlamaContext_tokenize(

    const char *text_chars = env->GetStringUTFChars(text, nullptr);

-    const std::vector<llama_token> toks = llama_tokenize(
+    const std::vector<llama_token> toks = common_tokenize(
        llama->ctx,
        text_chars,
        false
@@ -623,8 +642,8 @@ Java_com_rnllama_LlamaContext_embedding(

    llama->rewind();

-    llama_reset_timings(llama->ctx);
-
+    llama_perf_context_reset(llama->ctx);
+
    llama->params.prompt = text_chars;

    llama->params.n_predict = 0;
@@ -681,9 +700,16 @@ Java_com_rnllama_LlamaContext_freeContext(
    }
    if (llama->ctx_sampling != nullptr)
    {
-        llama_sampling_free(llama->ctx_sampling);
+        common_sampler_free(llama->ctx_sampling);
    }
    context_map.erase((long) llama->ctx);
}

+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
+    UNUSED(env);
+    UNUSED(thiz);
+    llama_log_set(log_callback, NULL);
+}
+
} // extern "C"
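For reference, a minimal sketch (an assumption, not part of this commit) of the Java-side declaration that the new JNI symbol implies. The class and package are inferred from the Java_com_rnllama_LlamaContext_* prefix, and the native library name is hypothetical; the actual sources may differ.

// LlamaContext.java (hypothetical sketch)
package com.rnllama;

public class LlamaContext {
    static {
        System.loadLibrary("rnllama"); // assumed library name
    }

    // Matches Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *, jobject)
    public native void logToAndroid();
}

Calling logToAndroid() once (for example, after creating a context) installs log_callback via llama_log_set, so llama.cpp/ggml log output is forwarded to Android logcat at the matching priority.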