diff --git a/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp b/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp index fd546712d..8295d75d2 100644 --- a/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp +++ b/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp @@ -529,7 +529,9 @@ TextGenerationResult LlamaCppTextGeneration::generate(const TextGenerationReques result.prompt_tokens = prompt_tokens; result.inference_time_ms = duration.count(); - if (cancel_requested_.load()) { + if (decode_failed_) { + result.finish_reason = "error"; + } else if (cancel_requested_.load()) { result.finish_reason = "cancelled"; } else if (success) { result.finish_reason = tokens_generated >= request.max_tokens ? "length" : "stop"; @@ -548,7 +550,15 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques return false; } + // Clear KV cache before each new generation to avoid position conflicts on + // sequential calls (fixes #356: SIGABRT on second decode on Android arm64). 
+ llama_memory_t mem = llama_get_memory(context_); + if (mem) { + llama_memory_clear(mem, true); + } + cancel_requested_.store(false); + decode_failed_ = false; std::string prompt = build_prompt(request); LOGI("Generating with prompt length: %zu", prompt.length()); @@ -717,6 +727,7 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques if (llama_decode(context_, batch) != 0) { LOGE("llama_decode failed during generation"); + decode_failed_ = true; break; } } @@ -725,7 +736,9 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques callback(stop_window); } - llama_memory_clear(llama_get_memory(context_), true); + if (llama_memory_t post_mem = llama_get_memory(context_)) { + llama_memory_clear(post_mem, true); + } llama_batch_free(batch); diff --git a/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.h b/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.h index 1387d5491..dc348a595 100644 --- a/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.h +++ b/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.h @@ -134,6 +134,7 @@ class LlamaCppTextGeneration { bool model_loaded_ = false; std::atomic<bool> cancel_requested_{false}; + std::atomic<bool> decode_failed_{false}; std::string model_path_; nlohmann::json model_config_; diff --git a/sdk/runanywhere-commons/src/backends/llamacpp/rac_llm_llamacpp.cpp b/sdk/runanywhere-commons/src/backends/llamacpp/rac_llm_llamacpp.cpp index ac6b40955..babfbf286 100644 --- a/sdk/runanywhere-commons/src/backends/llamacpp/rac_llm_llamacpp.cpp +++ b/sdk/runanywhere-commons/src/backends/llamacpp/rac_llm_llamacpp.cpp @@ -207,6 +207,13 @@ rac_result_t rac_llm_llamacpp_generate(rac_handle_t handle, const char* prompt, } RAC_LOG_INFO("LLM.LlamaCpp", "rac_llm_llamacpp_generate: generate() returned, tokens=%d", result.tokens_generated); + // finish_reason is std::string; TODO: migrate to enum if TextGenerationResult gains one + if (result.finish_reason 
== "error") { + RAC_LOG_ERROR("LLM.LlamaCpp", "rac_llm_llamacpp_generate: generation failed (e.g. llama_decode error)"); + rac_error_set_details("Generation failed: llama_decode returned non-zero"); + return RAC_ERROR_GENERATION_FAILED; + } + // Fill RAC result struct out_result->text = result.text.empty() ? nullptr : strdup(result.text.c_str()); out_result->completion_tokens = result.tokens_generated; diff --git a/sdk/runanywhere-commons/src/jni/runanywhere_commons_jni.cpp b/sdk/runanywhere-commons/src/jni/runanywhere_commons_jni.cpp index 5a0e3e51a..6e1b0bdb0 100644 --- a/sdk/runanywhere-commons/src/jni/runanywhere_commons_jni.cpp +++ b/sdk/runanywhere-commons/src/jni/runanywhere_commons_jni.cpp @@ -17,7 +17,6 @@ #include #include -#include #include #include #include @@ -568,6 +567,18 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate if (status != RAC_SUCCESS) { LOGe("racLlmComponentGenerate failed with status=%d", status); + rac_llm_result_free(&result); + const char* msg = rac_error_message(status); + jclass exClass = env->FindClass("java/lang/RuntimeException"); + if (exClass) { + char fallback[64]; + if (!msg || !*msg) { + snprintf(fallback, sizeof(fallback), "LLM generation failed (status=%d)", status); + msg = fallback; + } + env->ThrowNew(exClass, msg); + env->DeleteLocalRef(exClass); + } return nullptr; } @@ -863,6 +874,17 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate if (status != RAC_SUCCESS) { LOGe("rac_llm_component_generate_stream failed with status=%d", status); + const char* msg = rac_error_message(status); + jclass exClass = env->FindClass("java/lang/RuntimeException"); + if (exClass) { + char fallback[64]; + if (!msg || !*msg) { + snprintf(fallback, sizeof(fallback), "LLM stream generation failed (status=%d)", status); + msg = fallback; + } + env->ThrowNew(exClass, msg); + env->DeleteLocalRef(exClass); + } return nullptr; }