Commit f2f2a9d

Add a parameter to pass tokenizer to llama QNN runner
Differential Revision: D77910880
Pull Request resolved: #12285
1 parent 8611b23 commit f2f2a9d

3 files changed: +43 −27 lines
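
In short: the Runner constructor gains a trailing std::unique_ptr<tokenizers::Tokenizer> parameter (defaulting to nullptr) so a caller can inject a pre-loaded tokenizer instead of having load() build one from tokenizer_path_. A minimal sketch of the new call pattern; the constructor arguments ahead of kv_updater are assumptions, since only the trailing parameters appear in this diff:

// Hypothetical call site. Everything before kv_updater is assumed; only
// get_tiktoken_for_llama() and the trailing parameters are taken from the diff.
auto tokenizer = example::get_tiktoken_for_llama();
ET_CHECK_MSG(
    tokenizer->load(tokenizer_path) == tokenizers::Error::Ok,
    "failed to load tokenizer %s",
    tokenizer_path.c_str());
example::Runner runner(
    model_path,
    tokenizer_path,
    performance_output_path,
    temperature,
    eval_mode,
    /*kv_updater=*/"SmartMask",
    /*ngram=*/0,
    /*window=*/0,
    /*gcap=*/0,
    std::move(tokenizer)); // new: caller-owned tokenizer, moved into the runner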

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 38 additions & 26 deletions
@@ -61,14 +61,16 @@ Runner::Runner(
     const std::string& kv_updater,
     const int ngram,
     const int window,
-    const int gcap)
-    : tokenizer_path_(tokenizer_path),
+    const int gcap,
+    std::unique_ptr<tokenizers::Tokenizer> tokenizer)
+    : ngram_(ngram),
+      window_(window),
+      gcap_(gcap),
+      tokenizer_path_(tokenizer_path),
       performance_output_path_(performance_output_path),
       temperature_(temperature),
       eval_mode_(static_cast<EvalMode>(eval_mode)),
-      ngram_(ngram),
-      window_(window),
-      gcap_(gcap) {
+      tokenizer_(std::move(tokenizer)) {
   module_ = std::make_unique<Module>(
       model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
   stats_.reset();
@@ -115,30 +117,40 @@ Error Runner::load() {
       break;
   }
 
-  // load tokenizer. Assuming tiktoken is the default tokenizer
-  tokenizer_ = get_tiktoken_for_llama();
-  auto err = tokenizer_->load(tokenizer_path_);
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
-  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-  // fallback to BPE tokenizer.
-  if (err != tokenizers::Error::Ok) {
-    ET_LOG(
-        Info,
-        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
-    err = tokenizer_->load(tokenizer_path_);
-    llama_version_ = LlamaVersion::kLlama2;
-    ET_CHECK_MSG(
-        err == tokenizers::Error::Ok,
-        "failed to load tokenizer %s",
-        tokenizer_path_.c_str());
-  } else {
+  // TODO: remove this once we could release the new tokens used for the
+  // tokenizer
+  if (tokenizer_ != nullptr) {
     eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
-    llama_version_ = LlamaVersion::kLlama3;
+    eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
+    eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
+  } else {
+    // load tokenizer. Assuming tiktoken is the default tokenizer
+    tokenizer_ = get_tiktoken_for_llama();
+    auto err = tokenizer_->load(tokenizer_path_);
+    auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
+    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
+    // fallback to BPE tokenizer.
+    if (err != tokenizers::Error::Ok) {
+      ET_LOG(
+          Info,
+          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
+          tokenizer_path_.c_str());
+      tokenizer_.reset();
+      tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
+      err = tokenizer_->load(tokenizer_path_);
+      llama_version_ = LlamaVersion::kLlama2;
+      ET_CHECK_MSG(
+          err == tokenizers::Error::Ok,
+          "failed to load tokenizer %s",
+          tokenizer_path_.c_str());
+    } else {
+      eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
+      llama_version_ = LlamaVersion::kLlama3;
+    }
+    eos_ids->insert(tokenizer_->eos_tok());
   }
-  eos_ids->insert(tokenizer_->eos_tok());
+
   int32_t vocab_size = tokenizer_->vocab_size();
   decoder_runner_ =
       std::make_unique<DecoderRunner>(module_.get(), vocab_size, temperature_);
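
When a tokenizer is injected, the runner derives its stop-token ids by encoding the literal stop strings instead of branching on a detected Llama version. A minimal sketch of that pattern, assuming (as the diff's own calls do) that encode(text, /*bos=*/0, /*eos=*/0) returns a result whose first id is the special token itself:

#include <unordered_set>

#include <pytorch/tokenizers/tokenizer.h>

// Sketch: round-trip each literal stop string through the tokenizer to
// recover its id, mirroring the injected-tokenizer branch above. Assumes
// each stop string encodes to a single special-token id.
std::unordered_set<uint64_t> collect_eos_ids(tokenizers::Tokenizer& tok) {
  std::unordered_set<uint64_t> eos_ids;
  for (const char* stop : {"<|eot_id|>", "<|eot|>", "<|end_of_text|>"}) {
    // bos=0, eos=0: no BOS/EOS markers are added around the input.
    eos_ids.insert(tok.encode(stop, 0, 0).get()[0]);
  }
  return eos_ids;
}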

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 3 additions & 1 deletion
@@ -24,6 +24,7 @@
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/module/module.h>
 #include <pytorch/tokenizers/tokenizer.h>
+
 namespace example {
 
 enum LlamaVersion {
@@ -41,7 +42,8 @@ class Runner {
       const std::string& kv_updater = "SmartMask",
       const int ngram = 0,
       const int window = 0,
-      const int gcap = 0);
+      const int gcap = 0,
+      std::unique_ptr<tokenizers::Tokenizer> tokenizer = nullptr);
 
   bool is_loaded() const;
   executorch::runtime::Error load();
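
Note the defaulted parameter keeps the change source-compatible: existing call sites omit the argument, get nullptr, and hit the old tiktoken/BPE fallback in load(). A self-contained illustration of the idiom with hypothetical Widget/Loader names (not from this codebase):

#include <iostream>
#include <memory>
#include <string>

// Hypothetical stand-ins for Runner / tokenizers::Tokenizer, showing the
// defaulted-unique_ptr injection idiom this commit adopts.
struct Loader {
  std::string name;
};

class Widget {
 public:
  explicit Widget(std::unique_ptr<Loader> loader = nullptr)
      : loader_(std::move(loader)) {}

  void load() {
    if (loader_ == nullptr) {
      // Fallback, analogous to the runner constructing its own tokenizer.
      loader_ = std::make_unique<Loader>(Loader{"default"});
    }
    std::cout << "using loader: " << loader_->name << "\n";
  }

 private:
  std::unique_ptr<Loader> loader_;
};

int main() {
  Widget old_style; // pre-change call site: no argument, fallback path
  old_style.load(); // prints "using loader: default"

  Widget injected(std::make_unique<Loader>(Loader{"caller-owned"}));
  injected.load(); // prints "using loader: caller-owned"
}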

examples/qualcomm/oss_scripts/llama/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ def define_common_targets():
         "//executorch/extension/evalue_util:print_evalue",
         "//executorch/backends/qualcomm/runtime:runtime",
         "//pytorch/tokenizers:llama2c_tokenizer",
+        "//pytorch/tokenizers:regex_lookahead",
+        "//pytorch/tokenizers:tiktoken",
     ],
     external_deps = [
         "gflags",
