@@ -61,14 +61,16 @@ Runner::Runner(
     const std::string& kv_updater,
     const int ngram,
     const int window,
-    const int gcap)
-    : tokenizer_path_(tokenizer_path),
+    const int gcap,
+    std::unique_ptr<tokenizers::Tokenizer> tokenizer)
+    : ngram_(ngram),
+      window_(window),
+      gcap_(gcap),
+      tokenizer_path_(tokenizer_path),
       performance_output_path_(performance_output_path),
       temperature_(temperature),
       eval_mode_(static_cast<EvalMode>(eval_mode)),
-      ngram_(ngram),
-      window_(window),
-      gcap_(gcap) {
+      tokenizer_(std::move(tokenizer)) {
   module_ = std::make_unique<Module>(
       model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
   stats_.reset();
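With this hunk, the constructor takes ownership of an optional, caller-provided tokenizer instead of always loading one from `tokenizer_path_` during `load()`. Below is a minimal sketch of that ownership pattern, assuming a simplified stand-in `ExampleRunner` and stub `tokenizers` classes; the real constructor has more parameters (model path, temperature, eval mode, etc.) that are elided here.

```cpp
#include <memory>
#include <string>
#include <utility>

namespace tokenizers {
// Stand-in for the real tokenizers::Tokenizer interface (assumption).
class Tokenizer {
 public:
  virtual ~Tokenizer() = default;
};
class Llama2cTokenizer : public Tokenizer {};
} // namespace tokenizers

// "ExampleRunner" is a hypothetical stand-in, not the Runner class in this diff.
class ExampleRunner {
 public:
  ExampleRunner(
      const std::string& tokenizer_path,
      std::unique_ptr<tokenizers::Tokenizer> tokenizer)
      : tokenizer_path_(tokenizer_path), tokenizer_(std::move(tokenizer)) {}

  bool has_injected_tokenizer() const { return tokenizer_ != nullptr; }

 private:
  std::string tokenizer_path_;
  std::unique_ptr<tokenizers::Tokenizer> tokenizer_;
};

int main() {
  // A caller-provided tokenizer is moved in and owned by the runner...
  ExampleRunner with_tokenizer(
      "tokenizer.model", std::make_unique<tokenizers::Llama2cTokenizer>());
  // ...while passing nullptr keeps the old load-from-path fallback.
  ExampleRunner without_tokenizer("tokenizer.model", nullptr);
  return with_tokenizer.has_injected_tokenizer() ? 0 : 1;
}
```

Passing the tokenizer by `std::unique_ptr` makes the transfer of ownership explicit at the call site and lets `nullptr` signal "load it yourself", which is exactly the branch the second hunk adds to `load()`.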
@@ -115,30 +117,40 @@ Error Runner::load() {
       break;
   }

-  // load tokenizer. Assuming tiktoken is the default tokenizer
-  tokenizer_ = get_tiktoken_for_llama();
-  auto err = tokenizer_->load(tokenizer_path_);
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
-  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-  // fallback to BPE tokenizer.
-  if (err != tokenizers::Error::Ok) {
-    ET_LOG(
-        Info,
-        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
-    err = tokenizer_->load(tokenizer_path_);
-    llama_version_ = LlamaVersion::kLlama2;
-    ET_CHECK_MSG(
-        err == tokenizers::Error::Ok,
-        "failed to load tokenizer %s",
-        tokenizer_path_.c_str());
-  } else {
+  // TODO: remove this once we could release the new tokens used for the
+  // tokenizer
+  if (tokenizer_ != nullptr) {
     eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
-    llama_version_ = LlamaVersion::kLlama3;
+    eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
+    eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
+  } else {
+    // load tokenizer. Assuming tiktoken is the default tokenizer
+    tokenizer_ = get_tiktoken_for_llama();
+    auto err = tokenizer_->load(tokenizer_path_);
+    auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
+    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
+    // fallback to BPE tokenizer.
+    if (err != tokenizers::Error::Ok) {
+      ET_LOG(
+          Info,
+          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
+          tokenizer_path_.c_str());
+      tokenizer_.reset();
+      tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
+      err = tokenizer_->load(tokenizer_path_);
+      llama_version_ = LlamaVersion::kLlama2;
+      ET_CHECK_MSG(
+          err == tokenizers::Error::Ok,
+          "failed to load tokenizer %s",
+          tokenizer_path_.c_str());
+    } else {
+      eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
+      llama_version_ = LlamaVersion::kLlama3;
+    }
+    eos_ids->insert(tokenizer_->eos_tok());
   }
-  eos_ids->insert(tokenizer_->eos_tok());
+
   int32_t vocab_size = tokenizer_->vocab_size();
   decoder_runner_ =
       std::make_unique<DecoderRunner>(module_.get(), vocab_size, temperature_);
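The `load()` hunk prefers the injected tokenizer and, when one is present, registers several stop strings as EOS ids; otherwise it keeps the old tiktoken-then-BPE fallback. A minimal sketch of that EOS-id collection follows, with a hypothetical `FakeTokenizer` and `collect_eos_ids` standing in for the real tokenizer interface and runner logic; the token ids it produces are toy values, not real vocabulary entries.

```cpp
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for the real tokenizer; encode() returns token ids.
struct FakeTokenizer {
  std::vector<uint64_t> encode(const std::string& text, int /*bos*/, int /*eos*/) const {
    return {std::hash<std::string>{}(text) % 1000};  // toy ids, sketch only
  }
  uint64_t eos_tok() const { return 2; }
};

// Mirrors the branch added to Runner::load(): an injected tokenizer contributes
// the new stop strings, while the fallback path relies on "<|eot_id|>" (for
// Llama 3 artifacts) plus the tokenizer's own eos_tok().
std::unordered_set<uint64_t> collect_eos_ids(const FakeTokenizer* injected,
                                             const FakeTokenizer& fallback) {
  std::unordered_set<uint64_t> eos_ids;
  if (injected != nullptr) {
    for (const char* stop : {"<|eot_id|>", "<|eot|>", "<|end_of_text|>"}) {
      eos_ids.insert(injected->encode(stop, 0, 0)[0]);
    }
  } else {
    eos_ids.insert(fallback.encode("<|eot_id|>", 0, 0)[0]);
    eos_ids.insert(fallback.eos_tok());
  }
  return eos_ids;
}

int main() {
  FakeTokenizer injected, fallback;
  // Injected path: three stop strings registered.
  auto with_injected = collect_eos_ids(&injected, fallback);
  // Fallback path: "<|eot_id|>" plus eos_tok().
  auto without_injected = collect_eos_ids(nullptr, fallback);
  return with_injected.empty() || without_injected.empty();
}
```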