@@ -61,14 +61,16 @@ Runner::Runner(
     const std::string& kv_updater,
     const int ngram,
     const int window,
-    const int gcap)
-    : tokenizer_path_(tokenizer_path),
+    const int gcap,
+    std::unique_ptr<tokenizers::Tokenizer> tokenizer)
+    : ngram_(ngram),
+      window_(window),
+      gcap_(gcap),
+      tokenizer_path_(tokenizer_path),
       performance_output_path_(performance_output_path),
       temperature_(temperature),
       eval_mode_(static_cast<EvalMode>(eval_mode)),
-      ngram_(ngram),
-      window_(window),
-      gcap_(gcap) {
+      tokenizer_(std::move(tokenizer)) {
   module_ = std::make_unique<Module>(
       model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
   stats_.reset();
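
The constructor now takes ownership of an optional pre-built tokenizer, so a caller can inject one instead of relying on Runner::load() to probe the artifact on disk. A minimal caller-side sketch, assuming the surrounding project headers and assuming the leading argument names match the initializer list above (only the parameters visible in this hunk are confirmed by the diff):

    // Sketch only: build and load a tokenizer up front, then hand it to the
    // Runner. Argument names before kv_updater are assumptions, not confirmed
    // by this hunk.
    std::unique_ptr<tokenizers::Tokenizer> tokenizer = get_tiktoken_for_llama();
    if (tokenizer->load(tokenizer_path) != tokenizers::Error::Ok) {
      tokenizer = nullptr; // let Runner::load() fall back to its own probing
    }
    auto runner = std::make_unique<Runner>(
        model_path,
        tokenizer_path,
        performance_output_path,
        temperature,
        eval_mode,
        kv_updater,
        ngram,
        window,
        gcap,
        std::move(tokenizer));

Passing nullptr preserves the old behavior: load() probes the artifact as tiktoken and falls back to the Llama2 BPE tokenizer.
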
@@ -115,30 +117,37 @@ Error Runner::load() {
       break;
   }

-  // load tokenizer. Assuming tiktoken is the default tokenizer
-  tokenizer_ = get_tiktoken_for_llama();
-  auto err = tokenizer_->load(tokenizer_path_);
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
-  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-  // fallback to BPE tokenizer.
-  if (err != tokenizers::Error::Ok) {
-    ET_LOG(
-        Info,
-        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
-    err = tokenizer_->load(tokenizer_path_);
-    llama_version_ = LlamaVersion::kLlama2;
-    ET_CHECK_MSG(
-        err == tokenizers::Error::Ok,
-        "failed to load tokenizer %s",
-        tokenizer_path_.c_str());
-  } else {
+  if (tokenizer_ != nullptr) {
     eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
-    llama_version_ = LlamaVersion::kLlama3;
+    eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
+    eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
+  } else {
+    // load tokenizer. Assuming tiktoken is the default tokenizer
+    tokenizer_ = get_tiktoken_for_llama();
+    auto err = tokenizer_->load(tokenizer_path_);
+    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
+    // fallback to BPE tokenizer.
+    if (err != tokenizers::Error::Ok) {
+      ET_LOG(
+          Info,
+          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
+          tokenizer_path_.c_str());
+      tokenizer_.reset();
+      tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
+      err = tokenizer_->load(tokenizer_path_);
+      llama_version_ = LlamaVersion::kLlama2;
+      ET_CHECK_MSG(
+          err == tokenizers::Error::Ok,
+          "failed to load tokenizer %s",
+          tokenizer_path_.c_str());
+    } else {
+      eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
+      llama_version_ = LlamaVersion::kLlama3;
+    }
+    eos_ids->insert(tokenizer_->eos_tok());
   }
-  eos_ids->insert(tokenizer_->eos_tok());
+
   int32_t vocab_size = tokenizer_->vocab_size();
   decoder_runner_ =
       std::make_unique<DecoderRunner>(module_.get(), vocab_size, temperature_);
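
A side note on the special-token lookups above: encode(text, bos, eos) returns a result wrapper, and with bos and eos both 0 a special token such as "<|eot_id|>" encodes to a single id at index 0. A defensive sketch of the same lookup, assuming the wrapper exposes an ok() check alongside the get() accessor used in the diff:

    // Hypothetical helper, not part of the diff: insert a special token's id
    // only if the tokenizer actually produced one, instead of indexing the
    // result unconditionally.
    auto add_special = [&](const std::string& token) {
      auto res = tokenizer_->encode(token, /*bos=*/0, /*eos=*/0);
      if (res.ok() && !res.get().empty()) {
        eos_ids->insert(res.get()[0]);
      }
    };
    add_special("<|eot_id|>");
    add_special("<|eot|>");
    add_special("<|end_of_text|>");

This guards against an out-of-bounds read when a vocabulary lacks one of the stop tokens, which matters once several model families share the injected-tokenizer path.
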