Skip to content

Commit cbe41c1

Browse files
committed
Continue debugging
1 parent 0b5ff1e commit cbe41c1

File tree

1 file changed

+38
-6
lines changed

1 file changed

+38
-6
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
*/
88

99
#include <cassert>
10+
#include <cstddef>
11+
#include <cstdint>
1012
#include <chrono>
1113
#include <codecvt>
1214
#include <iostream>
@@ -36,6 +38,30 @@
3638
namespace llm = ::executorch::extension::llm;
3739
using ::executorch::runtime::Error;
3840

41+
namespace {
42+
bool utf8_check_validity(const char* str, size_t length) {
43+
for (size_t i = 0; i < length; ++i) {
44+
uint8_t byte = static_cast<uint8_t>(str[i]);
45+
if (byte >= 0x80) { // Non-ASCII byte
46+
if (i + 1 >= length) { // Incomplete sequence
47+
return false;
48+
}
49+
uint8_t next_byte = static_cast<uint8_t>(str[i + 1]);
50+
if ((byte & 0xE0) == 0xC0 && (next_byte & 0xC0) == 0x80) { // 2-byte sequence
51+
i += 2;
52+
} else if ((byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 && (i + 2 < length) && (static_cast<uint8_t>(str[i + 2]) & 0xC0) == 0x80) { // 3-byte sequence
53+
i += 3;
54+
} else {
55+
return false; // Invalid sequence
56+
}
57+
}
58+
}
59+
return true; // All bytes were valid
60+
}
61+
62+
std::string token_buffer;
63+
}
64+
3965
namespace executorch_jni {
4066

4167
class ExecuTorchLlamaCallbackJni
@@ -48,9 +74,17 @@ class ExecuTorchLlamaCallbackJni
4874
static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
4975
static const auto method =
5076
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");
51-
// static std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> converter;
52-
// std::u16string result_u16 = converter.from_bytes(result);
53-
// __android_log_print(ANDROID_LOG_ERROR, "ExecuTorchDBG", "U16:%s", result_u16.c_str());
77+
78+
token_buffer += result;
79+
if (utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
80+
__android_log_print(ANDROID_LOG_ERROR, "ExecuTorchDBG", "CONTINUE8:%s", token_buffer.c_str());
81+
return;
82+
}
83+
result = token_buffer;
84+
token_buffer = "";
85+
static std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> converter;
86+
std::u16string result_u16 = converter.from_bytes(result);
87+
__android_log_print(ANDROID_LOG_ERROR, "ExecuTorchDBG", "U16:%s", result_u16.c_str());
5488
facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
5589
method(self(), s);
5690
}
@@ -153,14 +187,12 @@ class ExecuTorchLlamaJni
153187
[callback](const llm::Stats& result) { callback->onStats(result); },
154188
echo);
155189
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
156-
std::string my_result;
157190
runner_->generate(
158191
prompt->toStdString(),
159192
seq_len,
160-
[callback, &my_result](std::string result) { my_result += result; },
193+
[callback](std::string result) { callback->onResult(result); },
161194
[callback](const llm::Stats& result) { callback->onStats(result); },
162195
echo);
163-
callback->onResult(my_result);
164196
}
165197
return 0;
166198
}

0 commit comments

Comments
 (0)