@@ -61,14 +61,16 @@ Runner::Runner(
     const std::string& kv_updater,
     const int ngram,
     const int window,
-    const int gcap)
-    : tokenizer_path_(tokenizer_path),
+    const int gcap,
+    std::unique_ptr<tokenizers::Tokenizer> tokenizer)
+    : ngram_(ngram),
+      window_(window),
+      gcap_(gcap),
+      tokenizer_path_(tokenizer_path),
       performance_output_path_(performance_output_path),
       temperature_(temperature),
       eval_mode_(static_cast<EvalMode>(eval_mode)),
-      ngram_(ngram),
-      window_(window),
-      gcap_(gcap) {
+      tokenizer_(std::move(tokenizer)) {
   module_ = std::make_unique<Module>(
       model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
   stats_.reset();
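For orientation, here is a hypothetical caller-side sketch (not part of this commit) of the widened constructor: the Runner now accepts an optional pre-loaded tokenizer as its trailing argument, and passing nullptr preserves the old behaviour of loading a tokenizer from tokenizer_path_ inside load(). The argument names, their order ahead of kv_updater, and the placeholder values are assumptions inferred from the member initializers above, not taken from the commit.

// Hypothetical usage sketch; assumes the Runner and tokenizers headers from
// this example are included. Values marked "assumed" are illustrative only.
#include <memory>
#include <string>
#include <utility>

std::unique_ptr<Runner> make_runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const std::string& kv_updater,
    std::unique_ptr<tokenizers::Tokenizer> tokenizer) {  // may be nullptr
  return std::make_unique<Runner>(
      model_path,
      tokenizer_path,
      /*performance_output_path=*/"outputs/perf.txt",  // assumed path
      /*temperature=*/0.8f,                            // assumed value
      /*eval_mode=*/0,                                 // assumed value
      kv_updater,
      /*ngram=*/0,
      /*window=*/0,
      /*gcap=*/0,
      std::move(tokenizer));  // nullptr -> load() falls back to tokenizer_path
}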
@@ -115,30 +117,39 @@ Error Runner::load() {
       break;
   }

-  // load tokenizer. Assuming tiktoken is the default tokenizer
-  tokenizer_ = get_tiktoken_for_llama();
-  auto err = tokenizer_->load(tokenizer_path_);
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
-  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-  // fallback to BPE tokenizer.
-  if (err != tokenizers::Error::Ok) {
-    ET_LOG(
-        Info,
-        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
-    err = tokenizer_->load(tokenizer_path_);
-    llama_version_ = LlamaVersion::kLlama2;
-    ET_CHECK_MSG(
-        err == tokenizers::Error::Ok,
-        "failed to load tokenizer %s",
-        tokenizer_path_.c_str());
-  } else {
+  // TODO: remove this once we could release the new tokens used for the tokenizer
+  if (tokenizer_ != nullptr) {
     eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
-    llama_version_ = LlamaVersion::kLlama3;
+    eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
+    eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
+  } else {
+    // load tokenizer. Assuming tiktoken is the default tokenizer
+    tokenizer_ = get_tiktoken_for_llama();
+    auto err = tokenizer_->load(tokenizer_path_);
+    auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
+    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
+    // fallback to BPE tokenizer.
+    if (err != tokenizers::Error::Ok) {
+      ET_LOG(
+          Info,
+          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
+          tokenizer_path_.c_str());
+      tokenizer_.reset();
+      tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
+      err = tokenizer_->load(tokenizer_path_);
+      llama_version_ = LlamaVersion::kLlama2;
+      ET_CHECK_MSG(
+          err == tokenizers::Error::Ok,
+          "failed to load tokenizer %s",
+          tokenizer_path_.c_str());
+    } else {
+      eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
+      llama_version_ = LlamaVersion::kLlama3;
+    }
+    eos_ids->insert(tokenizer_->eos_tok());
   }
-  eos_ids->insert(tokenizer_->eos_tok());
+
   int32_t vocab_size = tokenizer_->vocab_size();
   decoder_runner_ =
       std::make_unique<DecoderRunner>(module_.get(), vocab_size, temperature_);
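The load() change reads as: if a tokenizer was injected through the constructor, register the chat end tokens (<|eot_id|>, <|eot|>, <|end_of_text|>) directly; otherwise keep the old path of trying Tiktoken (Llama 3) and falling back to the Llama2c BPE tokenizer, then add the tokenizer's regular EOS token. Below is a minimal free-standing sketch of that decision flow using only calls visible in the hunk; collect_eos_ids is a hypothetical helper name, the single shared eos_ids set is a simplification of this sketch, and llama_version_ tracking is omitted.

// Sketch only: assumes get_tiktoken_for_llama(), tokenizers::Llama2cTokenizer
// and ET_CHECK_MSG are available from the example's headers.
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>

std::unordered_set<uint64_t> collect_eos_ids(
    std::unique_ptr<tokenizers::Tokenizer>& tokenizer,  // replaced on fallback
    const std::string& tokenizer_path) {
  std::unordered_set<uint64_t> eos_ids;
  if (tokenizer != nullptr) {
    // Injected tokenizer: register the chat-template end markers directly.
    for (const char* tok : {"<|eot_id|>", "<|eot|>", "<|end_of_text|>"}) {
      eos_ids.insert(tokenizer->encode(tok, 0, 0).get()[0]);
    }
  } else {
    // No tokenizer supplied: try Tiktoken first (Llama 3), then BPE (Llama 2).
    tokenizer = get_tiktoken_for_llama();
    if (tokenizer->load(tokenizer_path) == tokenizers::Error::Ok) {
      eos_ids.insert(tokenizer->encode("<|eot_id|>", 0, 0).get()[0]);
    } else {
      tokenizer = std::make_unique<tokenizers::Llama2cTokenizer>();
      ET_CHECK_MSG(
          tokenizer->load(tokenizer_path) == tokenizers::Error::Ok,
          "failed to load tokenizer %s",
          tokenizer_path.c_str());
    }
    eos_ids.insert(tokenizer->eos_tok());
  }
  return eos_ids;
}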