@@ -16,7 +16,7 @@
 #include <executorch/extension/llm/runner/util.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
-#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
+#include <pytorch/tokenizers/llama2c_tokenizer.h>
 
 namespace example {
 
@@ -78,17 +78,21 @@ Error Runner::load() {
   // load tokenizer. Assuming tiktoken is the default tokenizer
   tokenizer_ = nullptr;
   tokenizer_ = get_tiktoken_for_llama();
-  Error err = tokenizer_->load(tokenizer_path_);
+  ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
   // Rely on tiktoken to throw error if the artifact is incompatible. Then we
   // fallback to BPE tokenizer.
-  if (err == Error::InvalidArgument) {
+  if (err != ::tokenizers::Error::Ok) {
     ET_LOG(
         Info,
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
         tokenizer_path_.c_str());
     tokenizer_.reset();
-    tokenizer_ = std::make_unique<llm::BPETokenizer>();
-    tokenizer_->load(tokenizer_path_);
+    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    err = tokenizer_->load(tokenizer_path_);
+    ET_CHECK_TK_OK_OR_RETURN_ERROR(
+        err,
+        "Failed to load %s as a llama2.c tokenizer artifact",
+        tokenizer_path_.c_str());
   }
 
   ET_LOG(Info, "Reading metadata from model");
@@ -201,12 +205,12 @@ Error Runner::generate(
       ? seq_len
       : metadata_.at(kMaxSeqLen);
 
-  Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
+  ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
       prompt,
       /* bos */ 0,
       /* eos */ 0);
 
-  ET_CHECK_OK_OR_RETURN_ERROR(
+  ET_CHECK_TK_OK_OR_RETURN_ERROR(
       encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
 
   // encode the (string) prompt into tokens sequence
@@ -242,7 +246,8 @@ Error Runner::generate(
   uint64_t cur_token = prefill_res.get();
 
   // print the first token from prefill. No prev_token so use cur_token for it.
-  wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
+  wrapped_callback(
+      ET_UNWRAP_TOKENIZER(tokenizer_->decode(cur_token, cur_token)));
   RUNNER_ET_LOG(
       warmup,
       "RSS after prompt prefill: %f MiB (0 if unsupported)",
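
Note for readers porting similar code: below is a minimal standalone sketch of the load-with-fallback pattern the Runner::load() hunk above introduces. Everything it uses is taken from this diff except the ::tokenizers::Tokenizer base class, its <pytorch/tokenizers/tokenizer.h> header, and the load_tokenizer helper name, which are assumptions here.

// Sketch: try tiktoken first, fall back to the llama2.c (BPE) format.
// Assumed (not shown in the diff): the ::tokenizers::Tokenizer base class
// and its <pytorch/tokenizers/tokenizer.h> header.
#include <memory>
#include <string>

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#include <pytorch/tokenizers/llama2c_tokenizer.h>
#include <pytorch/tokenizers/tokenizer.h>

std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
    const std::string& path) {
  // Tiktoken is the default for current Llama models.
  std::unique_ptr<::tokenizers::Tokenizer> tok =
      example::get_tiktoken_for_llama();
  if (tok->load(path) == ::tokenizers::Error::Ok) {
    return tok;
  }
  // Any non-Ok error triggers the fallback, mirroring the `err != Ok` check.
  tok = std::make_unique<::tokenizers::Llama2cTokenizer>();
  if (tok->load(path) != ::tokenizers::Error::Ok) {
    return nullptr; // let the caller surface the failure
  }
  return tok;
}

One design point worth noting: the new code falls back on any non-Ok error, whereas the old code only did so on Error::InvalidArgument, so a corrupt tiktoken artifact now also routes through the llama2.c loader, whose failure is then reported by the new ET_CHECK_TK_OK_OR_RETURN_ERROR.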
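Similarly, a sketch of the encode/decode round trip from the Runner::generate() hunks, written without the ET_CHECK_TK_OK_OR_RETURN_ERROR / ET_UNWRAP_TOKENIZER convenience macros. It assumes ::tokenizers::Result exposes ok() and get() in the style of ExecuTorch's Result type (only error() and get() usage is visible in this diff); the tokenizer.bin path, prompt, and bos/eos counts are illustrative.

// Sketch: encode a prompt, then decode it token by token with explicit
// Result checks instead of the ET_* macros used in the runner.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include <pytorch/tokenizers/llama2c_tokenizer.h>

int main() {
  ::tokenizers::Llama2cTokenizer tok;
  if (tok.load("tokenizer.bin") != ::tokenizers::Error::Ok) {
    std::cerr << "failed to load tokenizer\n";
    return 1;
  }
  // encode(text, bos, eos): same parameter order as the runner's call.
  ::tokenizers::Result<std::vector<uint64_t>> ids =
      tok.encode("hello world", /* bos */ 1, /* eos */ 0);
  if (!ids.ok()) {
    std::cerr << "encode failed\n";
    return 1;
  }
  // decode(prev, cur) yields the text piece for `cur` given its predecessor;
  // the runner passes cur_token twice for the very first token.
  uint64_t prev = ids.get().front();
  for (uint64_t cur : ids.get()) {
    auto piece = tok.decode(prev, cur);
    if (piece.ok()) {
      std::cout << piece.get();
    }
    prev = cur;
  }
  std::cout << "\n";
  return 0;
}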