|
17 | 17 |
|
18 | 18 | #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h> |
19 | 19 | #include <executorch/extension/llm/tokenizer/bpe_tokenizer.h> |
| 20 | +#include <executorch/extension/llm/tokenizer/hf_tokenizer.h> |
20 | 21 |
|
21 | 22 | namespace example { |
22 | 23 |
|
@@ -75,20 +76,33 @@ Error Runner::load() { |
75 | 76 | return Error::Ok; |
76 | 77 | } |
77 | 78 | ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); |
78 | | - // load tokenizer. Assuming tiktoken is the default tokenizer |
| 79 | + // Load tokenizer. |
79 | 80 | tokenizer_ = nullptr; |
80 | | - tokenizer_ = get_tiktoken_for_llama(); |
81 | | - Error err = tokenizer_->load(tokenizer_path_); |
82 | | - // Rely on tiktoken to throw error if the artifact is incompatible. Then we |
83 | | - // fallback to BPE tokenizer. |
84 | | - if (err == Error::InvalidArgument) { |
85 | | - ET_LOG( |
86 | | - Info, |
87 | | - "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer", |
88 | | - tokenizer_path_.c_str()); |
89 | | - tokenizer_.reset(); |
90 | | - tokenizer_ = std::make_unique<llm::BPETokenizer>(); |
| 81 | + // Check if tokenizer_path_ ends with ".json". |
| 82 | + if (tokenizer_path_.size() >= 5 && |
| 83 | + tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) { |
| 84 | + tokenizer_ = std::make_unique<llm::HfTokenizer>(); |
91 | 85 | tokenizer_->load(tokenizer_path_); |
| 86 | + ET_LOG( |
| 87 | + Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str()); |
| 88 | + } else { |
| 89 | + // Else assume TikToken is the default tokenizer, using BPE as a fallback. |
| 90 | + tokenizer_ = get_tiktoken_for_llama(); |
| 91 | + Error err = tokenizer_->load(tokenizer_path_); |
| 92 | + if (err == Error::InvalidArgument) { |
| 93 | + tokenizer_.reset(); |
| 94 | + tokenizer_ = std::make_unique<llm::BPETokenizer>(); |
| 95 | + tokenizer_->load(tokenizer_path_); |
| 96 | + ET_LOG( |
| 97 | + Info, |
| 98 | + "Loaded tokenizer %s as BPE tokenizer", |
| 99 | + tokenizer_path_.c_str()); |
| 100 | + } else { |
| 101 | + ET_LOG( |
| 102 | + Info, |
| 103 | + "Loaded tokenizer %s as TikToken tokenizer", |
| 104 | + tokenizer_path_.c_str()); |
| 105 | + } |
92 | 106 | } |
93 | 107 |
|
94 | 108 | ET_LOG(Info, "Reading metadata from model"); |
|
0 commit comments