Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,9 @@ TextGenerationResult LlamaCppTextGeneration::generate(const TextGenerationReques
result.prompt_tokens = prompt_tokens;
result.inference_time_ms = duration.count();

if (cancel_requested_.load()) {
if (decode_failed_) {
result.finish_reason = "error";
} else if (cancel_requested_.load()) {
result.finish_reason = "cancelled";
} else if (success) {
result.finish_reason = tokens_generated >= request.max_tokens ? "length" : "stop";
Expand All @@ -548,7 +550,15 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques
return false;
}

// Clear KV cache before each new generation to avoid position conflicts on
// sequential calls (fixes #356: SIGABRT on second decode on Android arm64).
llama_memory_t mem = llama_get_memory(context_);
if (mem) {
llama_memory_clear(mem, true);
}

cancel_requested_.store(false);
decode_failed_ = false;

std::string prompt = build_prompt(request);
LOGI("Generating with prompt length: %zu", prompt.length());
Expand Down Expand Up @@ -717,6 +727,7 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques

if (llama_decode(context_, batch) != 0) {
LOGE("llama_decode failed during generation");
decode_failed_ = true;
break;
}
}
Expand All @@ -725,7 +736,9 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques
callback(stop_window);
}

llama_memory_clear(llama_get_memory(context_), true);
if (llama_memory_t post_mem = llama_get_memory(context_)) {
llama_memory_clear(post_mem, true);
}

llama_batch_free(batch);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class LlamaCppTextGeneration {

bool model_loaded_ = false;
std::atomic<bool> cancel_requested_{false};
std::atomic<bool> decode_failed_{false};

std::string model_path_;
nlohmann::json model_config_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ rac_result_t rac_llm_llamacpp_generate(rac_handle_t handle, const char* prompt,
}
RAC_LOG_INFO("LLM.LlamaCpp", "rac_llm_llamacpp_generate: generate() returned, tokens=%d", result.tokens_generated);

// finish_reason is std::string; TODO: migrate to enum if TextGenerationResult gains one
if (result.finish_reason == "error") {
RAC_LOG_ERROR("LLM.LlamaCpp", "rac_llm_llamacpp_generate: generation failed (e.g. llama_decode error)");
rac_error_set_details("Generation failed: llama_decode returned non-zero");
return RAC_ERROR_GENERATION_FAILED;
}

// Fill RAC result struct
out_result->text = result.text.empty() ? nullptr : strdup(result.text.c_str());
out_result->completion_tokens = result.tokens_generated;
Expand Down
24 changes: 23 additions & 1 deletion sdk/runanywhere-commons/src/jni/runanywhere_commons_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include <jni.h>

#include <condition_variable>
#include <chrono>
#include <cstring>
#include <mutex>
#include <string>
Expand Down Expand Up @@ -568,6 +567,18 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate

if (status != RAC_SUCCESS) {
LOGe("racLlmComponentGenerate failed with status=%d", status);
rac_llm_result_free(&result);
const char* msg = rac_error_message(status);
jclass exClass = env->FindClass("java/lang/RuntimeException");
if (exClass) {
char fallback[64];
if (!msg || !*msg) {
snprintf(fallback, sizeof(fallback), "LLM generation failed (status=%d)", status);
msg = fallback;
}
env->ThrowNew(exClass, msg);
env->DeleteLocalRef(exClass);
}
return nullptr;
}

Expand Down Expand Up @@ -863,6 +874,17 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate

if (status != RAC_SUCCESS) {
LOGe("rac_llm_component_generate_stream failed with status=%d", status);
const char* msg = rac_error_message(status);
jclass exClass = env->FindClass("java/lang/RuntimeException");
if (exClass) {
char fallback[64];
if (!msg || !*msg) {
snprintf(fallback, sizeof(fallback), "LLM stream generation failed (status=%d)", status);
msg = fallback;
}
env->ThrowNew(exClass, msg);
env->DeleteLocalRef(exClass);
}
return nullptr;
}

Expand Down
Loading