diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 6ffc88d8103..1049b9da308 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -6,11 +6,9 @@ * LICENSE file in the root directory of this source tree. */ -#include #include -#include +#include #include -#include #include #include #include @@ -33,6 +31,43 @@ namespace llm = ::executorch::extension::llm; using ::executorch::runtime::Error; +namespace { +bool utf8_check_validity(const char* str, size_t length) { + for (size_t i = 0; i < length; ++i) { + uint8_t byte = static_cast(str[i]); + if (byte >= 0x80) { // Non-ASCII byte + if (i + 1 >= length) { // Incomplete sequence + return false; + } + uint8_t next_byte = static_cast(str[i + 1]); + if ((byte & 0xE0) == 0xC0 && + (next_byte & 0xC0) == 0x80) { // 2-byte sequence + i += 2; + } else if ( + (byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 && + (i + 2 < length) && + (static_cast(str[i + 2]) & 0xC0) == + 0x80) { // 3-byte sequence + i += 3; + } else if ( + (byte & 0xF8) == 0xF0 && (next_byte & 0xC0) == 0x80 && + (i + 2 < length) && + (static_cast(str[i + 2]) & 0xC0) == 0x80 && + (i + 3 < length) && + (static_cast(str[i + 3]) & 0xC0) == + 0x80) { // 4-byte sequence + i += 4; + } else { + return false; // Invalid sequence + } + } + } + return true; // All bytes were valid +} + +std::string token_buffer; +} // namespace + namespace executorch_jni { class ExecuTorchLlamaCallbackJni @@ -45,6 +80,15 @@ class ExecuTorchLlamaCallbackJni static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic(); static const auto method = cls->getMethod)>("onResult"); + + token_buffer += result; + if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) { + ET_LOG( + Info, "Current token buffer is not valid UTF-8. Waiting for more."); + return; + } + result = token_buffer; + token_buffer = ""; facebook::jni::local_ref s = facebook::jni::make_jstring(result); method(self(), s); }