diff --git a/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp b/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp
index 27733806b..e645006d0 100644
--- a/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp
+++ b/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include
 
 #include "rac/core/rac_logger.h"
 
@@ -15,39 +16,38 @@
 
 namespace runanywhere {
 
-// =============================================================================
-// UTF-8 VALIDATION HELPER
-// =============================================================================
-
-static bool is_valid_utf8(const char* string) {
-    if (!string)
-        return true;
-
-    const unsigned char* bytes = (const unsigned char*)string;
-    int num;
-
-    while (*bytes != 0x00) {
-        if ((*bytes & 0x80) == 0x00) {
-            num = 1;
-        } else if ((*bytes & 0xE0) == 0xC0) {
-            num = 2;
-        } else if ((*bytes & 0xF0) == 0xE0) {
-            num = 3;
-        } else if ((*bytes & 0xF8) == 0xF0) {
-            num = 4;
-        } else {
-            return false;
-        }
+// UTF-8 STATE MACHINE (DFA): table-driven validator that is fed one byte at a time.
+
+struct Utf8State {
+
+    uint32_t state = 0;  // current DFA state; 0 means "at a sequence boundary", 1 is absorbing (row s1 is all 1s)
+
+    // Bjoern Hoehrmann LUT: entries 0..255 map a byte to a character class; the rest are 16-wide transition rows, one per state.
+    bool process(uint8_t byte) {  // consumes one byte; returns true iff the DFA is back in state 0
+        static const uint8_t utf8d[] = {
+            0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 00..1f
+            0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 20..3f
+            0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 40..5f
+            0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 60..7f
+            1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, // 80..9f
+            7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, // a0..bf
+            8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, // c0..df
+            0xa,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x4,0x3,0x3, // e0..ef
+            0xb,0x6,0x6,0x6,0x5,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8, // f0..ff
+            0x0,0x1,0x2,0x3,0x5,0x8,0x7,0x1,0x1,0x1,0x4,0x6,0x1,0x1,0x1,0x1, // s0..s0
+            1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,0,1,0,1,1,1,1,1,1, // s1..s2
+            1,2,1,1,1,1,1,2,1,2,1,1,1,1,1,1,1,1,1,1,1,1,1,2,1,1,1,1,1,1,1,1, // s3..s4
+            1,2,1,1,1,1,1,1,1,2,1,1,1,1,1,1,1,1,1,1,1,1,1,3,1,3,1,1,1,1,1,1, // s5..s6
+            1,3,1,1,1,1,1,3,1,3,1,1,1,1,1,1,1,3,1,1,1,1,1,1,1,1,1,1,1,1,1,1, // s7..s8
+        };
 
-        bytes += 1;
-        for (int i = 1; i < num; ++i) {
-            if ((*bytes & 0xC0) != 0x80)
-                return false;
-            bytes += 1;
-        }
+        uint32_t type = utf8d[byte];  // character class of this byte
+        state = utf8d[256 + state * 16 + type];  // next state, read from the row of the current state
+        return (state == 0);
     }
-    return true;
-}
+
+    void reset() { state = 0; }  // back to the initial state so a new scan can start
+};
 
 // =============================================================================
 // LOG CALLBACK
@@ -457,12 +457,8 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques
     }
 
     int effective_max_tokens = std::min(request.max_tokens, available_tokens);
-    if (effective_max_tokens < request.max_tokens) {
-        LOGI("Capping max_tokens: %d → %d (context=%d, prompt=%d tokens)", request.max_tokens,
-             effective_max_tokens, n_ctx, prompt_tokens);
-    }
-    LOGI("Generation: prompt_tokens=%d, max_tokens=%d, context=%d", prompt_tokens,
-         effective_max_tokens, n_ctx);
+    LOGI("Generation: prompt_tokens=%d, max_tokens=%d, context=%d",
+         prompt_tokens, effective_max_tokens, n_ctx);
 
     llama_batch batch = llama_batch_init(n_ctx, 0, 1);
 
@@ -481,10 +477,27 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques
     llama_sampler_reset(sampler_);
 
     const auto vocab = llama_model_get_vocab(model_);
-    std::string cached_token_chars;
-    std::string accumulated_text;
+
+    static const std::vector STOP_SEQUENCES = {
+        "<|im_end|>", "<|eot_id|>", "", "<|end|>", "<|endoftext|>",  // NOTE(review): an empty entry matches at offset 0 of every find(); presumably a tag was lost in extraction - confirm
+        "\n\nUser:", "\n\nHuman:",
+    };
+
+    static const size_t MAX_STOP_LEN = []{  // byte length of the longest stop sequence, computed once
+        size_t m = 0;
+        for (const auto& s : STOP_SEQUENCES) m = std::max(m, s.size());
+        return m;
+    }();
+
+    std::string stop_window;  // validated text held back from the callback while it is searched for stop sequences
+    stop_window.reserve(MAX_STOP_LEN * 2);
+
+    std::string partial_utf8_buffer;  // token bytes not yet confirmed to end on a UTF-8 sequence boundary
+    partial_utf8_buffer.reserve(8);
+
     int n_cur = batch.n_tokens;
     int tokens_generated = 0;
+    bool stop_sequence_hit = false;
 
     while (tokens_generated < effective_max_tokens && !cancel_requested_.load()) {
         const llama_token new_token_id = llama_sampler_sample(sampler_, context_, -1);
@@ -496,41 +509,55 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques
             break;
         }
 
-        auto new_token_chars = common_token_to_piece(context_, new_token_id);
-        cached_token_chars += new_token_chars;
-        accumulated_text += new_token_chars;
-
-        static const std::vector stop_sequences = {
-            "<|im_end|>",
-            "<|eot_id|>",
-            "",
-            "<|end|>",
-            "<|endoftext|>",
-            "\n\nUser:",
-            "\n\nHuman:",
-        };
+        const std::string new_token_chars =
+            common_token_to_piece(context_, new_token_id);
 
-        bool hit_stop_sequence = false;
-        for (const auto& stop_seq : stop_sequences) {
-            size_t pos = accumulated_text.find(stop_seq);
-            if (pos != std::string::npos) {
-                LOGI("Stop sequence detected: %s", stop_seq.c_str());
-                hit_stop_sequence = true;
-                break;
+        partial_utf8_buffer.append(new_token_chars);
+
+        Utf8State scanner_state;  // fresh DFA per token: the whole pending buffer is rescanned from its first byte
+        size_t valid_upto = 0;  // length of the longest prefix that ends with the DFA in state 0
+        for (size_t i = 0; i < partial_utf8_buffer.size(); ++i) {
+            scanner_state.process(static_cast(partial_utf8_buffer[i]));  // NOTE(review): state 1 is absorbing, so an invalid byte left at the head of the buffer keeps valid_upto at 0 on every later token - confirm intended
+            if (scanner_state.state == 0) {
+                valid_upto = i + 1;
             }
         }
 
-        if (hit_stop_sequence) {
-            break;
-        }
+        if (valid_upto > 0) {
+            std::string valid_chunk = partial_utf8_buffer.substr(0, valid_upto);
+            stop_window.append(valid_chunk);
+            partial_utf8_buffer.erase(0, valid_upto);  // bytes past the validated prefix stay buffered for the next token
+
+            size_t found_stop_pos = std::string::npos;  // smallest offset at which any stop sequence occurs in stop_window
+            for (const auto& stop_seq : STOP_SEQUENCES) {
+                size_t pos = stop_window.find(stop_seq);
+                if (pos != std::string::npos) {
+                    if (found_stop_pos == std::string::npos || pos < found_stop_pos) {
+                        found_stop_pos = pos;
+                    }
+                }
+            }
 
-        if (is_valid_utf8(cached_token_chars.c_str())) {
-            if (!callback(cached_token_chars)) {
-                LOGI("Generation cancelled by callback");
-                cancel_requested_.store(true);
+            if (found_stop_pos != std::string::npos) {
+                LOGI("Stop sequence detected");
+                stop_sequence_hit = true;
+                if (found_stop_pos > 0) {  // emit only the text that precedes the stop sequence
+                    if (!callback(stop_window.substr(0, found_stop_pos))) {
+                        cancel_requested_.store(true);
+                    }
+                }
                 break;
             }
-            cached_token_chars.clear();
+
+            if (stop_window.size() > MAX_STOP_LEN) {
+                size_t safe_len = stop_window.size() - MAX_STOP_LEN;  // NOTE(review): raw byte offset; nothing here aligns it to a UTF-8 boundary, so the emitted prefix may end mid-character - confirm
+                if (!callback(stop_window.substr(0, safe_len))) {
+                    LOGI("Generation cancelled by callback");
+                    cancel_requested_.store(true);
+                    break;
+                }
+                stop_window.erase(0, safe_len);  // keep the last MAX_STOP_LEN bytes so a stop sequence split across tokens can still match
+            }
         }
 
         batch.n_tokens = 0;
@@ -545,8 +572,8 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques
         }
     }
 
-    if (!cached_token_chars.empty() && is_valid_utf8(cached_token_chars.c_str())) {
-        callback(cached_token_chars);
+    if (!cancel_requested_.load() && !stop_sequence_hit && !stop_window.empty()) {  // flush the held-back tail unless cancelled or stopped by a stop sequence
+        callback(stop_window);
     }
 
     llama_memory_clear(llama_get_memory(context_), true);