Skip to content

Commit 6b2a082

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Fix issue with partial UTF-8 string (#6317)
Summary: It will cause JNI exception if we don't pass in UTF-8 string. Alternative 1 (this): wait until we have complete UTF-8 tokens. Alternative 2 (?): Fix this from runner layer Alternative 3 (no): Change the API to use uint8_t array, but if we want to display on app in real time, this is still an issue. Pull Request resolved: #6317 Reviewed By: Riandy Differential Revision: D64580932 Pulled By: kirklandsign fbshipit-source-id: 341ea906097707fae0f97d32dea974ad44425083
1 parent 0eeea82 commit 6b2a082

File tree

1 file changed

+47
-3
lines changed

1 file changed

+47
-3
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cassert>
109
#include <chrono>
11-
#include <iostream>
10+
#include <cstdint>
1211
#include <memory>
13-
#include <sstream>
1412
#include <string>
1513
#include <unordered_map>
1614
#include <vector>
@@ -33,6 +31,43 @@
3331
namespace llm = ::executorch::extension::llm;
3432
using ::executorch::runtime::Error;
3533

34+
namespace {
35+
bool utf8_check_validity(const char* str, size_t length) {
36+
for (size_t i = 0; i < length; ++i) {
37+
uint8_t byte = static_cast<uint8_t>(str[i]);
38+
if (byte >= 0x80) { // Non-ASCII byte
39+
if (i + 1 >= length) { // Incomplete sequence
40+
return false;
41+
}
42+
uint8_t next_byte = static_cast<uint8_t>(str[i + 1]);
43+
if ((byte & 0xE0) == 0xC0 &&
44+
(next_byte & 0xC0) == 0x80) { // 2-byte sequence
45+
i += 2;
46+
} else if (
47+
(byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 &&
48+
(i + 2 < length) &&
49+
(static_cast<uint8_t>(str[i + 2]) & 0xC0) ==
50+
0x80) { // 3-byte sequence
51+
i += 3;
52+
} else if (
53+
(byte & 0xF8) == 0xF0 && (next_byte & 0xC0) == 0x80 &&
54+
(i + 2 < length) &&
55+
(static_cast<uint8_t>(str[i + 2]) & 0xC0) == 0x80 &&
56+
(i + 3 < length) &&
57+
(static_cast<uint8_t>(str[i + 3]) & 0xC0) ==
58+
0x80) { // 4-byte sequence
59+
i += 4;
60+
} else {
61+
return false; // Invalid sequence
62+
}
63+
}
64+
}
65+
return true; // All bytes were valid
66+
}
67+
68+
std::string token_buffer;
69+
} // namespace
70+
3671
namespace executorch_jni {
3772

3873
class ExecuTorchLlamaCallbackJni
@@ -45,6 +80,15 @@ class ExecuTorchLlamaCallbackJni
4580
static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
4681
static const auto method =
4782
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");
83+
84+
token_buffer += result;
85+
if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
86+
ET_LOG(
87+
Info, "Current token buffer is not valid UTF-8. Waiting for more.");
88+
return;
89+
}
90+
result = token_buffer;
91+
token_buffer = "";
4892
facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
4993
method(self(), s);
5094
}

0 commit comments

Comments
 (0)