Skip to content

Commit 79b97e4

Browse files
authored
[tokenizer] Consolidate how runner decide which tokenizer to use
Differential Revision: D62160344 Pull Request resolved: #5052
1 parent 3716680 commit 79b97e4

File tree

7 files changed

+134
-47
lines changed

7 files changed

+134
-47
lines changed

examples/models/llama2/runner/runner.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,19 @@ Error Runner::load() {
6969
return Error::Ok;
7070
}
7171
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
72-
// load tokenizer
72+
// load tokenizer. Assuming tiktoken is the default tokenizer
7373
tokenizer_ = nullptr;
74-
tokenizer_ = std::make_unique<BPETokenizer>();
74+
tokenizer_ = get_tiktoken_for_llama();
7575
Error err = tokenizer_->load(tokenizer_path_);
76+
// Rely on tiktoken to throw error if the artifact is incompatible. Then we
77+
// fallback to BPE tokenizer.
7678
if (err == Error::InvalidArgument) {
7779
ET_LOG(
7880
Info,
79-
"Failed to load %s as a BPETokenizer artifact, trying Tiktoken",
81+
"Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
8082
tokenizer_path_.c_str());
8183
tokenizer_.reset();
82-
tokenizer_ = get_tiktoken_for_llama();
84+
tokenizer_ = std::make_unique<BPETokenizer>();
8385
tokenizer_->load(tokenizer_path_);
8486
}
8587

extension/llm/tokenizer/base64.h

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
#pragma once
2626

27+
#include <executorch/runtime/core/error.h>
28+
#include <executorch/runtime/core/result.h>
2729
#include <executorch/runtime/platform/assert.h>
2830
#include <cassert>
2931
#include <string>
@@ -32,10 +34,13 @@
3234
namespace executorch {
3335
namespace extension {
3436
namespace llm {
37+
using Error = executorch::runtime::Error;
38+
template <typename T>
39+
using Result = executorch::runtime::Result<T>;
3540

3641
namespace base64 {
3742

38-
std::string decode(const std::string_view& input);
43+
Result<std::string> decode(const std::string_view& input);
3944

4045
namespace detail {
4146

@@ -59,118 +64,135 @@ constexpr uint32_t DECODE_TABLE[] = {
5964
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
6065
255};
6166

62-
inline void validate(uint32_t v) {
63-
ET_CHECK_MSG(v != 255, "invalid char");
67+
inline Error validate(uint32_t v) {
68+
ET_CHECK_OR_RETURN_ERROR(v != 255, InvalidArgument, "invalid char");
69+
return Error::Ok;
6470
}
6571

66-
inline void decode(const std::string_view& input, std::string& output) {
67-
ET_CHECK_MSG(
68-
input.size() == 4, "input length must be 4, got %zu", input.size());
72+
inline Error decode(const std::string_view& input, std::string& output) {
73+
ET_CHECK_OR_RETURN_ERROR(
74+
input.size() == 4,
75+
InvalidArgument,
76+
"input length must be 4, got %zu",
77+
input.size());
6978

7079
uint32_t val = 0;
7180

7281
uint8_t c = input[0];
7382
auto v = DECODE_TABLE[c];
74-
validate(v);
83+
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
7584
val = v;
7685

7786
c = input[1];
7887
v = DECODE_TABLE[c];
79-
validate(v);
88+
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
8089
val = (val << 6) | v;
8190

8291
c = input[2];
8392
v = DECODE_TABLE[c];
84-
validate(v);
93+
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
8594
val = (val << 6) | v;
8695

8796
c = input[3];
8897
v = DECODE_TABLE[c];
89-
validate(v);
98+
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
9099
val = (val << 6) | v;
91100

92101
output.push_back(static_cast<char>((val >> 16) & 0xFF));
93102
output.push_back(static_cast<char>((val >> 8) & 0xFF));
94103
output.push_back(static_cast<char>(val & 0xFF));
104+
return Error::Ok;
95105
}
96106

97-
inline void decode_1_padding(
107+
inline Error decode_1_padding(
98108
const std::string_view& input,
99109
std::string& output) {
100-
ET_CHECK_MSG(
101-
input.size() == 3, "input length must be 3, got %zu", input.size());
110+
ET_CHECK_OR_RETURN_ERROR(
111+
input.size() == 3,
112+
InvalidArgument,
113+
"input length must be 3, got %zu",
114+
input.size());
102115

103116
uint32_t val = 0;
104117

105118
uint8_t c = input[0];
106119
auto v = DECODE_TABLE[c];
107-
validate(v);
120+
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
108121
val = v;
109122

110123
c = input[1];
111124
v = DECODE_TABLE[c];
112-
validate(v);
125+
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
113126
val = (val << 6) | v;
114127

115128
c = input[2];
116129
v = DECODE_TABLE[c];
117-
validate(v);
130+
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
118131
val = (val << 6) | v;
119132

120133
output.push_back(static_cast<char>((val >> 10) & 0xFF));
121134
output.push_back(static_cast<char>((val >> 2) & 0xFF));
135+
return Error::Ok;
122136
}
123137

124-
inline void decode_2_padding(
138+
inline Error decode_2_padding(
125139
const std::string_view& input,
126140
std::string& output) {
127-
assert(input.size() == 2);
141+
ET_CHECK_OR_RETURN_ERROR(
142+
input.size() == 2,
143+
InvalidArgument,
144+
"input length must be 2, got %zu",
145+
input.size());
128146

129147
uint32_t val = 0;
130148

131149
uint8_t c = input[0];
132150
auto v = DECODE_TABLE[c];
133-
validate(v);
151+
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
134152
val = v;
135153

136154
c = input[1];
137155
v = DECODE_TABLE[c];
138-
validate(v);
156+
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
139157
val = (val << 6) | v;
140158

141159
output.push_back(static_cast<char>((val >> 4) & 0xFF));
160+
return Error::Ok;
142161
}
143162

144163
} // namespace detail
145164

146-
inline std::string decode(const std::string_view& input) {
147-
ET_CHECK_MSG(!input.empty(), "empty input");
165+
inline Result<std::string> decode(const std::string_view& input) {
166+
ET_CHECK_OR_RETURN_ERROR(!input.empty(), InvalidArgument, "empty input");
148167

149168
// Faster than `input.size() % 4`.
150-
ET_CHECK_MSG(
169+
ET_CHECK_OR_RETURN_ERROR(
151170
(input.size() & 3) == 0 && input.size() >= 4,
171+
InvalidArgument,
152172
"input length must be larger than 4 and is multiple of 4, got %zu",
153173
input.size());
154174

155175
std::string output;
156176
output.reserve(input.size() / 4 * 3);
157177
auto idx = 0U;
158178
for (; idx < input.size() - 4; idx += 4) {
159-
detail::decode(input.substr(idx, 4), output);
179+
ET_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
160180
}
161181

162182
// Last 4 bytes. Might contain paddings.
163183
if (input[idx + 3] == '=') {
164184
if (input[idx + 2] == '=') {
165185
// Tow paddings.
166-
detail::decode_2_padding(input.substr(idx, 2), output);
186+
ET_CHECK_OK_OR_RETURN_ERROR(
187+
detail::decode_2_padding(input.substr(idx, 2), output));
167188
} else {
168189
// One padding.
169-
detail::decode_1_padding(input.substr(idx, 3), output);
190+
ET_CHECK_OK_OR_RETURN_ERROR(
191+
detail::decode_1_padding(input.substr(idx, 3), output));
170192
}
171193
} else {
172194
// No padding.
173-
detail::decode(input.substr(idx, 4), output);
195+
ET_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
174196
}
175197

176198
return output;
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tet 0
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ICAgICAgIA== 18446744073709551616
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ICAgICAgIA==10

extension/llm/tokenizer/test/test_tiktoken.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
#include <executorch/extension/llm/tokenizer/tiktoken.h>
1010
#include <executorch/runtime/platform/runtime.h>
11+
#include <gmock/gmock.h>
1112
#include <gtest/gtest.h>
13+
#include <sstream>
1214
#include <vector>
1315

1416
using namespace ::testing;
@@ -140,3 +142,47 @@ TEST_F(TiktokenExtensionTest, ConstructionWithInvalidEOSIndex) {
140142
"");
141143
#endif
142144
}
145+
146+
TEST_F(TiktokenExtensionTest, LoadWithInvalidPath) {
147+
auto invalidModelPath =
148+
std::getenv("RESOURCES_PATH") + std::string("/nonexistent.model");
149+
150+
Error res = tokenizer_->load(invalidModelPath.c_str());
151+
EXPECT_EQ(res, Error::InvalidArgument);
152+
}
153+
154+
TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidRank) {
155+
auto invalidModelPath = std::getenv("RESOURCES_PATH") +
156+
std::string("/test_tiktoken_invalid_rank.model");
157+
158+
Error res = tokenizer_->load(invalidModelPath.c_str());
159+
160+
EXPECT_EQ(res, Error::InvalidArgument);
161+
}
162+
163+
TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidBase64) {
164+
auto invalidModelPath = std::getenv("RESOURCES_PATH") +
165+
std::string("/test_tiktoken_invalid_base64.model");
166+
167+
Error res = tokenizer_->load(invalidModelPath.c_str());
168+
169+
EXPECT_EQ(res, Error::InvalidArgument);
170+
}
171+
172+
TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithNoSpace) {
173+
auto invalidModelPath = std::getenv("RESOURCES_PATH") +
174+
std::string("/test_tiktoken_no_space.model");
175+
176+
Error res = tokenizer_->load(invalidModelPath.c_str());
177+
178+
EXPECT_EQ(res, Error::InvalidArgument);
179+
}
180+
181+
TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithBPEFile) {
182+
auto invalidModelPath =
183+
std::getenv("RESOURCES_PATH") + std::string("/test_bpe_tokenizer.bin");
184+
185+
Error res = tokenizer_->load(invalidModelPath.c_str());
186+
187+
EXPECT_EQ(res, Error::InvalidArgument);
188+
}

extension/llm/tokenizer/tiktoken.cpp

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include <executorch/extension/llm/tokenizer/base64.h>
2929
#include <executorch/extension/llm/tokenizer/tiktoken.h>
30+
#include <executorch/runtime/core/result.h>
3031
#include <fstream>
3132
#include <limits>
3233

@@ -65,47 +66,60 @@ static Re2UPtr _build_special_token_regex(const Encoder& special_encoder) {
6566
return _create_regex(special_pattern);
6667
}
6768

68-
static std::pair<std::string, uint64_t> _parse(const std::string& line) {
69+
static Result<std::pair<std::string, uint64_t>> _parse(
70+
const std::string& line) {
71+
// Tiktoken format
72+
// https://github.com/openai/tiktoken/blob/main/tiktoken/load.py#L140 <base64
73+
// encoded token str> <rank>
6974
auto pos = line.find(" ");
70-
ET_CHECK_MSG(
71-
pos != std::string::npos, "invalid encoder line: %s", line.c_str());
75+
ET_CHECK_OR_RETURN_ERROR(
76+
pos != std::string::npos,
77+
InvalidArgument,
78+
"invalid tiktoken line: %s",
79+
line.c_str());
7280

73-
auto token = base64::decode({line.data(), pos});
81+
auto token = ET_UNWRAP(base64::decode({line.data(), pos}));
7482
uint64_t rank = 0;
7583
try {
7684
rank = std::stoul(line.substr(pos + 1));
7785
} catch (const std::exception&) {
78-
ET_CHECK_MSG(false, "invalid encoder rank: %s", line.c_str());
86+
ET_CHECK_OR_RETURN_ERROR(
87+
false, InvalidArgument, "invalid encoder rank: %s", line.c_str());
7988
}
8089

81-
return {std::move(token), rank};
90+
return std::pair{std::move(token), rank};
8291
}
8392

84-
static Encoder _load_encoder(const std::string& path) {
93+
static Result<Encoder> _load_encoder(const std::string& path) {
8594
std::ifstream file(path);
86-
ET_CHECK_MSG(file, "failed to open encoder file: %s", path.c_str());
95+
ET_CHECK_OR_RETURN_ERROR(
96+
file, InvalidArgument, "failed to open encoder file: %s", path.c_str());
8797

8898
Encoder encoder;
8999
std::string line;
90100
while (std::getline(file, line)) {
91-
auto [token, rank] = _parse(line);
101+
auto [token, rank] = ET_UNWRAP(_parse(line));
92102

93-
ET_CHECK_MSG(
103+
ET_CHECK_OR_RETURN_ERROR(
94104
encoder.emplace(std::move(token), rank).second,
105+
InvalidArgument,
95106
"duplicate item: %s",
96107
line.c_str());
97108
}
98109

99110
return encoder;
100111
}
101112

102-
static Decoder _build_decoder(const Encoder& encoder) {
113+
static Result<Decoder> _build_decoder(const Encoder& encoder) {
103114
Decoder decoder;
104115
for (const auto& [k, v] : encoder) {
105116
decoder.emplace(v, k);
106117
}
107118

108-
ET_CHECK_MSG(encoder.size() == decoder.size(), "duplicate items in encoder");
119+
ET_CHECK_OR_RETURN_ERROR(
120+
encoder.size() == decoder.size(),
121+
InvalidArgument,
122+
"duplicate items in encoder");
109123

110124
return decoder;
111125
}
@@ -356,11 +370,11 @@ Tiktoken::Tiktoken(
356370
}
357371

358372
Error Tiktoken::load(const std::string& path) {
359-
_encoder = _load_encoder(path);
373+
_encoder = ET_UNWRAP(_load_encoder(path));
360374
_special_token_encoder = _build_special_token_encoder(_encoder.size());
361375

362-
_decoder = _build_decoder(_encoder);
363-
_special_token_decoder = _build_decoder(_special_token_encoder);
376+
_decoder = ET_UNWRAP(_build_decoder(_encoder));
377+
_special_token_decoder = ET_UNWRAP(_build_decoder(_special_token_encoder));
364378

365379
_regex = _create_regex(_pattern);
366380
// Warmup re2 as it is slow on the first run, void the return value as it's
@@ -393,7 +407,7 @@ Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) const {
393407
for (auto i = 0; i < eos; ++i) {
394408
res.push_back(eos_tok_);
395409
}
396-
return Result(res);
410+
return Result<std::vector<uint64_t>>(std::move(res));
397411
}
398412

399413
Result<std::string> Tiktoken::decode(uint64_t prev, uint64_t cur) const {

0 commit comments

Comments
 (0)