Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cassert>
#include <chrono>
#include <iostream>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
Expand All @@ -33,6 +31,43 @@
namespace llm = ::executorch::extension::llm;
using ::executorch::runtime::Error;

namespace {
bool utf8_check_validity(const char* str, size_t length) {
for (size_t i = 0; i < length; ++i) {
uint8_t byte = static_cast<uint8_t>(str[i]);
if (byte >= 0x80) { // Non-ASCII byte
if (i + 1 >= length) { // Incomplete sequence
return false;
}
uint8_t next_byte = static_cast<uint8_t>(str[i + 1]);
if ((byte & 0xE0) == 0xC0 &&
(next_byte & 0xC0) == 0x80) { // 2-byte sequence
i += 2;
} else if (
(byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 &&
(i + 2 < length) &&
(static_cast<uint8_t>(str[i + 2]) & 0xC0) ==
0x80) { // 3-byte sequence
i += 3;
} else if (
(byte & 0xF8) == 0xF0 && (next_byte & 0xC0) == 0x80 &&
(i + 2 < length) &&
(static_cast<uint8_t>(str[i + 2]) & 0xC0) == 0x80 &&
(i + 3 < length) &&
(static_cast<uint8_t>(str[i + 3]) & 0xC0) ==
0x80) { // 4-byte sequence
i += 4;
} else {
return false; // Invalid sequence
}
}
}
return true; // All bytes were valid
}
Comment on lines +35 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this util be used by runner as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering whether we want to bring this to runner. It's only causing problem for Java though. For C++ runner, and iOS, seems that they use uint8_t as string, and they can print partial string, so won't cause this issue.


std::string token_buffer;
} // namespace

namespace executorch_jni {

class ExecuTorchLlamaCallbackJni
Expand All @@ -45,6 +80,15 @@ class ExecuTorchLlamaCallbackJni
static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
static const auto method =
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");

token_buffer += result;
if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
ET_LOG(
Info, "Current token buffer is not valid UTF-8. Waiting for more.");
return;
}
result = token_buffer;
token_buffer = "";
facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
method(self(), s);
}
Expand Down
Loading