 #include <executorch/extension/llm/runner/util.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
-#include <pytorch/tokenizers/llama2c_tokenizer.h>
 #include <pytorch/tokenizers/hf_tokenizer.h>
+#include <pytorch/tokenizers/llama2c_tokenizer.h>
 
 namespace example {
 
@@ -36,6 +36,41 @@ static constexpr auto kMaxContextLen = "get_max_context_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+
+std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
+    const std::string& tokenizer_path) {
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = nullptr;
+  ::tokenizers::Error err;
+
+  // First try to load as a json tokenizer.
+  {
+    auto tokenizer = std::make_unique<tokenizers::HFTokenizer>();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded json tokenizer");
+      return tokenizer;
+    }
+  }
+
+  // Try to load as tiktoken tokenizer.
+  {
+    auto tokenizer = get_tiktoken_for_llama();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded TikToken tokenizer");
+      return tokenizer;
+    }
+  }
+
+  // Try to load as BPE tokenizer.
+  {
+    auto tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded BPE tokenizer");
+      return tokenizer;
+    }
+  }
+
+  return nullptr;
+}
 } // namespace
 
 Runner::Runner(
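
Reviewer note, not part of the diff: the new load_tokenizer() helper keys the choice of tokenizer off the load result itself, trying the HF JSON tokenizer, then tiktoken, then the llama2.c/BPE format, and returning the first loader that succeeds (or nullptr if none do). Below is a minimal standalone sketch of that try-each-format fallback pattern; the Toy* loader types and load_toy_tokenizer() are made up for illustration and are not the real tokenizers API.

// Illustration only (hypothetical Toy* types, not the real tokenizers API):
// same try-each-format-in-order shape as load_tokenizer() above.
#include <memory>
#include <string>

struct ToyTokenizer {
  virtual ~ToyTokenizer() = default;
  virtual bool load(const std::string& path) = 0; // true on success
};

struct ToyJsonTokenizer : ToyTokenizer {
  bool load(const std::string& path) override {
    // Pretend only *.json artifacts parse as this format.
    return path.size() >= 5 && path.compare(path.size() - 5, 5, ".json") == 0;
  }
};

struct ToyBinaryTokenizer : ToyTokenizer {
  bool load(const std::string& path) override {
    return !path.empty(); // accept anything non-empty as a stand-in
  }
};

// Try each format in order; the first successful load wins, otherwise nullptr.
std::unique_ptr<ToyTokenizer> load_toy_tokenizer(const std::string& path) {
  {
    auto t = std::make_unique<ToyJsonTokenizer>();
    if (t->load(path)) {
      return t;
    }
  }
  {
    auto t = std::make_unique<ToyBinaryTokenizer>();
    if (t->load(path)) {
      return t;
    }
  }
  return nullptr; // caller decides how to report the failure
}

Compared with the removed extension check in Runner::load(), this shape does not care whether the artifact is named *.json; an unrecognized file simply falls through every branch.
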
@@ -78,35 +113,15 @@ Error Runner::load() {
     return Error::Ok;
   }
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
+
   // Load tokenizer.
-  tokenizer_ = nullptr;
-  // Check if tokenizer_path_ ends with ".json".
-  if (tokenizer_path_.size() >= 5 &&
-
-      tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) {
-    tokenizer_ = std::make_unique<tokenizers::HFTokenizer>();
-    ET_LOG(Info, "Loading json tokenizer");
-    tokenizer_->load(tokenizer_path_);
+  tokenizer_ = load_tokenizer(tokenizer_path_);
+  if (tokenizer_ == nullptr) {
     ET_LOG(
-        Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str());
-  } else {
-    ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-    tokenizer_ = get_tiktoken_for_llama();
-    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-    // fallback to BPE tokenizer.
-    if (err != ::tokenizers::Error::Ok) {
-      ET_LOG(
-          Info,
-          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-          tokenizer_path_.c_str());
-      tokenizer_.reset();
-      tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-      err = tokenizer_->load(tokenizer_path_);
-      ET_CHECK_TK_OK_OR_RETURN_ERROR(
-          err,
-          "Failed to load %s as a llama2.c tokenizer artifact",
-          tokenizer_path_.c_str());
-    }
+        Error,
+        "Failed to load %s as a llama2.c tokenizer artifact",
+        tokenizer_path_.c_str());
+    return ::executorch::runtime::Error::InvalidArgument;
   }
 
   ET_LOG(Info, "Reading metadata from model");
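
Reviewer note, not part of the diff: Runner::load() now reports an unloadable tokenizer to its caller as ::executorch::runtime::Error::InvalidArgument instead of terminating via ET_CHECK_TK_OK_OR_RETURN_ERROR in the BPE fallback. A hedged sketch of a call site follows; the runner object, its construction, and the warm_up() wrapper are assumptions for illustration, not part of this change.

// Assumes `runner` is an already-constructed example::Runner (construction
// arguments omitted; they are not shown in this diff).
int warm_up(example::Runner& runner) {
  const auto err = runner.load();
  if (err != ::executorch::runtime::Error::Ok) {
    // A missing or unrecognized tokenizer artifact now lands here as
    // Error::InvalidArgument rather than terminating inside load().
    return 1;
  }
  return 0;
}
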