Skip to content

Commit c3f08d2

Browse files
sakirr05 and sakirr authored
fix: clear KV cache and reset batch state between sequential decode calls on arm64 (RunanywhereAI#393)
* fix: clear KV cache and reset batch state between sequential decode calls on arm64

* fix: address bot review comments - null guard, decode failure flag, error details, and JNI exception fallback

* fix: make decode_failed_ std::atomic for thread safety (review)

---------

Co-authored-by: sakirr <sakirahmed75531@gmail.com>
1 parent 7094a8c commit c3f08d2

File tree

4 files changed

+46
-3
lines changed

4 files changed

+46
-3
lines changed

sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,9 @@ TextGenerationResult LlamaCppTextGeneration::generate(const TextGenerationReques
529529
result.prompt_tokens = prompt_tokens;
530530
result.inference_time_ms = duration.count();
531531

532-
if (cancel_requested_.load()) {
532+
if (decode_failed_) {
533+
result.finish_reason = "error";
534+
} else if (cancel_requested_.load()) {
533535
result.finish_reason = "cancelled";
534536
} else if (success) {
535537
result.finish_reason = tokens_generated >= request.max_tokens ? "length" : "stop";
@@ -548,7 +550,15 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques
548550
return false;
549551
}
550552

553+
// Clear KV cache before each new generation to avoid position conflicts on
554+
// sequential calls (fixes #356: SIGABRT on second decode on Android arm64).
555+
llama_memory_t mem = llama_get_memory(context_);
556+
if (mem) {
557+
llama_memory_clear(mem, true);
558+
}
559+
551560
cancel_requested_.store(false);
561+
decode_failed_ = false;
552562

553563
std::string prompt = build_prompt(request);
554564
LOGI("Generating with prompt length: %zu", prompt.length());
@@ -717,6 +727,7 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques
717727

718728
if (llama_decode(context_, batch) != 0) {
719729
LOGE("llama_decode failed during generation");
730+
decode_failed_ = true;
720731
break;
721732
}
722733
}
@@ -725,7 +736,9 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques
725736
callback(stop_window);
726737
}
727738

728-
llama_memory_clear(llama_get_memory(context_), true);
739+
if (llama_memory_t post_mem = llama_get_memory(context_)) {
740+
llama_memory_clear(post_mem, true);
741+
}
729742

730743
llama_batch_free(batch);
731744

sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class LlamaCppTextGeneration {
134134

135135
bool model_loaded_ = false;
136136
std::atomic<bool> cancel_requested_{false};
137+
std::atomic<bool> decode_failed_{false};
137138

138139
std::string model_path_;
139140
nlohmann::json model_config_;

sdk/runanywhere-commons/src/backends/llamacpp/rac_llm_llamacpp.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,13 @@ rac_result_t rac_llm_llamacpp_generate(rac_handle_t handle, const char* prompt,
207207
}
208208
RAC_LOG_INFO("LLM.LlamaCpp", "rac_llm_llamacpp_generate: generate() returned, tokens=%d", result.tokens_generated);
209209

210+
// finish_reason is std::string; TODO: migrate to enum if TextGenerationResult gains one
211+
if (result.finish_reason == "error") {
212+
RAC_LOG_ERROR("LLM.LlamaCpp", "rac_llm_llamacpp_generate: generation failed (e.g. llama_decode error)");
213+
rac_error_set_details("Generation failed: llama_decode returned non-zero");
214+
return RAC_ERROR_GENERATION_FAILED;
215+
}
216+
210217
// Fill RAC result struct
211218
out_result->text = result.text.empty() ? nullptr : strdup(result.text.c_str());
212219
out_result->completion_tokens = result.tokens_generated;

sdk/runanywhere-commons/src/jni/runanywhere_commons_jni.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include <jni.h>
1818

1919
#include <condition_variable>
20-
#include <chrono>
2120
#include <cstring>
2221
#include <mutex>
2322
#include <string>
@@ -568,6 +567,18 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate
568567

569568
if (status != RAC_SUCCESS) {
570569
LOGe("racLlmComponentGenerate failed with status=%d", status);
570+
rac_llm_result_free(&result);
571+
const char* msg = rac_error_message(status);
572+
jclass exClass = env->FindClass("java/lang/RuntimeException");
573+
if (exClass) {
574+
char fallback[64];
575+
if (!msg || !*msg) {
576+
snprintf(fallback, sizeof(fallback), "LLM generation failed (status=%d)", status);
577+
msg = fallback;
578+
}
579+
env->ThrowNew(exClass, msg);
580+
env->DeleteLocalRef(exClass);
581+
}
571582
return nullptr;
572583
}
573584

@@ -863,6 +874,17 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate
863874

864875
if (status != RAC_SUCCESS) {
865876
LOGe("rac_llm_component_generate_stream failed with status=%d", status);
877+
const char* msg = rac_error_message(status);
878+
jclass exClass = env->FindClass("java/lang/RuntimeException");
879+
if (exClass) {
880+
char fallback[64];
881+
if (!msg || !*msg) {
882+
snprintf(fallback, sizeof(fallback), "LLM stream generation failed (status=%d)", status);
883+
msg = fallback;
884+
}
885+
env->ThrowNew(exClass, msg);
886+
env->DeleteLocalRef(exClass);
887+
}
866888
return nullptr;
867889
}
868890

0 commit comments

Comments (0)