@@ -61,14 +61,16 @@ Runner::Runner(
     const std::string& kv_updater,
     const int ngram,
     const int window,
-    const int gcap)
-    : tokenizer_path_(tokenizer_path),
+    const int gcap,
+    std::unique_ptr<tokenizers::Tokenizer> tokenizer)
+    : ngram_(ngram),
+      window_(window),
+      gcap_(gcap),
+      tokenizer_path_(tokenizer_path),
       performance_output_path_(performance_output_path),
       temperature_(temperature),
       eval_mode_(static_cast<EvalMode>(eval_mode)),
-      ngram_(ngram),
-      window_(window),
-      gcap_(gcap) {
+      tokenizer_(std::move(tokenizer)) {
   module_ = std::make_unique<Module>(
       model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
   stats_.reset();
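
For reference, the constructor change above lets a caller hand the runner a pre-built tokenizer instead of having it load one from tokenizer_path_. A minimal caller-side sketch follows; only the new trailing tokenizer argument comes from this hunk, while the remaining parameter order, the example:: namespace, the runner header path, and the literal values are assumptions for illustration.

#include <memory>
#include <string>
#include <utility>

#include "runner.h"  // assumed header declaring example::Runner

// Hypothetical factory: pass a pre-built tokenizer, or nullptr to keep the
// runner's original load-from-path behavior. Parameter order is assumed.
std::unique_ptr<example::Runner> make_runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    std::unique_ptr<tokenizers::Tokenizer> tokenizer /* may be nullptr */) {
  return std::make_unique<example::Runner>(
      model_path,
      tokenizer_path,
      /*performance_output_path=*/"outputs/inference_speed.txt",
      /*temperature=*/0.8f,
      /*eval_mode=*/0,
      /*kv_updater=*/"SmartMask",
      /*ngram=*/0,
      /*window=*/0,
      /*gcap=*/0,
      std::move(tokenizer));  // new argument introduced by this change
}
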
@@ -115,30 +117,39 @@ Error Runner::load() {
       break;
   }

-  // load tokenizer. Assuming tiktoken is the default tokenizer
-  tokenizer_ = get_tiktoken_for_llama();
-  auto err = tokenizer_->load(tokenizer_path_);
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
-  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-  // fallback to BPE tokenizer.
-  if (err != tokenizers::Error::Ok) {
-    ET_LOG(
-        Info,
-        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
-    err = tokenizer_->load(tokenizer_path_);
-    llama_version_ = LlamaVersion::kLlama2;
-    ET_CHECK_MSG(
-        err == tokenizers::Error::Ok,
-        "failed to load tokenizer %s",
-        tokenizer_path_.c_str());
-  } else {
+  // TODO: remove this once we can release the new tokens used for the
+  // tokenizer
+  if (tokenizer_ != nullptr) {
     eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
-    llama_version_ = LlamaVersion::kLlama3;
+    eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
+    eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
+  } else {
+    // load tokenizer. Assuming tiktoken is the default tokenizer
+    tokenizer_ = get_tiktoken_for_llama();
+    auto err = tokenizer_->load(tokenizer_path_);
+    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
+    // fallback to BPE tokenizer.
+    if (err != tokenizers::Error::Ok) {
+      ET_LOG(
+          Info,
+          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
+          tokenizer_path_.c_str());
+      tokenizer_.reset();
+      tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
+      err = tokenizer_->load(tokenizer_path_);
+      llama_version_ = LlamaVersion::kLlama2;
+      ET_CHECK_MSG(
+          err == tokenizers::Error::Ok,
+          "failed to load tokenizer %s",
+          tokenizer_path_.c_str());
+    } else {
+      eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
+      llama_version_ = LlamaVersion::kLlama3;
+    }
+    eos_ids->insert(tokenizer_->eos_tok());
   }
-  eos_ids->insert(tokenizer_->eos_tok());
+
   int32_t vocab_size = tokenizer_->vocab_size();
   decoder_runner_ =
       std::make_unique<DecoderRunner>(module_.get(), vocab_size, temperature_);
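
Taken together, load() now prefers an injected tokenizer and only falls back to loading from tokenizer_path_ when none was supplied. The same resolution order is shown below in isolation as a free function; this is a sketch only, not code from this change, and the header locations plus the example::get_tiktoken_for_llama() qualification are assumptions.

#include <memory>
#include <string>

// Assumed header locations for the tokenizers library and the tiktoken helper.
#include <pytorch/tokenizers/llama2c_tokenizer.h>
#include <pytorch/tokenizers/tokenizer.h>
#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>

std::unique_ptr<tokenizers::Tokenizer> resolve_tokenizer(
    std::unique_ptr<tokenizers::Tokenizer> injected,
    const std::string& tokenizer_path) {
  // An injected tokenizer wins and skips file loading entirely.
  if (injected != nullptr) {
    return injected;
  }
  // Otherwise try Tiktoken first, then fall back to the Llama2c BPE tokenizer,
  // mirroring the else branch above.
  std::unique_ptr<tokenizers::Tokenizer> tok = example::get_tiktoken_for_llama();
  if (tok->load(tokenizer_path) != tokenizers::Error::Ok) {
    tok = std::make_unique<tokenizers::Llama2cTokenizer>();
    if (tok->load(tokenizer_path) != tokenizers::Error::Ok) {
      return nullptr;  // neither format matched
    }
  }
  return tok;
}

Resolving the tokenizer outside the runner lets callers supply one that already knows the extra stop tokens (<|eot|>, <|end_of_text|>), which appears to be what the TODO comment above refers to.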