64 changes: 38 additions & 26 deletions examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -61,14 +61,16 @@ Runner::Runner(
const std::string& kv_updater,
const int ngram,
const int window,
const int gcap)
: tokenizer_path_(tokenizer_path),
const int gcap,
std::unique_ptr<tokenizers::Tokenizer> tokenizer)
: ngram_(ngram),
window_(window),
gcap_(gcap),
tokenizer_path_(tokenizer_path),
performance_output_path_(performance_output_path),
temperature_(temperature),
eval_mode_(static_cast<EvalMode>(eval_mode)),
ngram_(ngram),
window_(window),
gcap_(gcap) {
tokenizer_(std::move(tokenizer)) {
module_ = std::make_unique<Module>(
model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
stats_.reset();
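One detail worth calling out in this hunk: the member-initializer list is reordered so that ngram_, window_, and gcap_ now come first and tokenizer_ comes last. The written order of an initializer list has no effect at runtime, since members are always constructed in declaration order, so the reorder was presumably made to keep the list in sync with the member declarations in runner.h and avoid -Wreorder warnings. A minimal, self-contained demo of that rule (nothing below is from the real Runner):

#include <cstdio>

struct Noisy {
  explicit Noisy(const char* name) { std::printf("constructing %s\n", name); }
};

struct Demo {
  // The initializer list names a_ first, but b_ is declared first, so
  // "constructing b" prints before "constructing a"; compilers flag the
  // mismatch with -Wreorder.
  Demo() : a_("a"), b_("b") {}
  Noisy b_;  // constructed first: declaration order, not list order, wins
  Noisy a_;  // constructed second
};

int main() {
  Demo d;
  (void)d;
  return 0;
}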
@@ -115,30 +117,40 @@ Error Runner::load() {
break;
}

// load tokenizer. Assuming tiktoken is the default tokenizer
tokenizer_ = get_tiktoken_for_llama();
auto err = tokenizer_->load(tokenizer_path_);
auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
// Rely on tiktoken to throw error if the artifact is incompatible. Then we
// fallback to BPE tokenizer.
if (err != tokenizers::Error::Ok) {
ET_LOG(
Info,
"Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
tokenizer_path_.c_str());
tokenizer_.reset();
tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
err = tokenizer_->load(tokenizer_path_);
llama_version_ = LlamaVersion::kLlama2;
ET_CHECK_MSG(
err == tokenizers::Error::Ok,
"failed to load tokenizer %s",
tokenizer_path_.c_str());
} else {
// TODO: remove this once we could release the new tokens used for the
// tokenizer
if (tokenizer_ != nullptr) {
[Inline review comment, Contributor] @limintang maybe let's add a TODO or comment here such that others know the context. @haowhsu-quic @shewu-quic this is for supporting an internal model, feel free to let us know your thoughts and we can address it.
eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
llama_version_ = LlamaVersion::kLlama3;
eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
} else {
// load tokenizer. Assuming tiktoken is the default tokenizer
tokenizer_ = get_tiktoken_for_llama();
auto err = tokenizer_->load(tokenizer_path_);
auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
// Rely on tiktoken to throw error if the artifact is incompatible. Then we
// fallback to BPE tokenizer.
if (err != tokenizers::Error::Ok) {
ET_LOG(
Info,
"Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
tokenizer_path_.c_str());
tokenizer_.reset();
tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
err = tokenizer_->load(tokenizer_path_);
llama_version_ = LlamaVersion::kLlama2;
ET_CHECK_MSG(
err == tokenizers::Error::Ok,
"failed to load tokenizer %s",
tokenizer_path_.c_str());
} else {
eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
llama_version_ = LlamaVersion::kLlama3;
}
eos_ids->insert(tokenizer_->eos_tok());
}
eos_ids->insert(tokenizer_->eos_tok());

int32_t vocab_size = tokenizer_->vocab_size();
decoder_runner_ =
std::make_unique<DecoderRunner>(module_.get(), vocab_size, temperature_);
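Because the rendered diff interleaves the old and new versions of this block, the resulting control flow is easier to see in a condensed sketch. The following is a self-contained approximation with stand-in types: the branch structure and token names come from the diff, but the exact statement placement (in particular where eos_ids is declared and which inserts sit in the injected-tokenizer branch) is inferred rather than confirmed, and the real code aborts via ET_CHECK_MSG where this sketch merely retries.

#include <cstdint>
#include <cstdio>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>

namespace sketch {

enum class LlamaVersion { kLlama2, kLlama3 };

// Stand-in for tokenizers::Tokenizer with only the calls this flow needs.
struct Tokenizer {
  explicit Tokenizer(bool loads_ok = true) : loads_ok_(loads_ok) {}
  bool load(const std::string& /*path*/) const { return loads_ok_; }
  uint64_t encode(const std::string& piece) const {
    return std::hash<std::string>{}(piece);  // fake but stable token id
  }
  uint64_t eos_tok() const { return 2; }
  bool loads_ok_;
};

struct Runner {
  std::unique_ptr<Tokenizer> tokenizer_;  // may be injected via constructor
  std::unordered_set<uint64_t> eos_ids_;
  LlamaVersion llama_version_ = LlamaVersion::kLlama3;

  void load(const std::string& path, bool tiktoken_compatible) {
    if (tokenizer_ != nullptr) {
      // Injected tokenizer: no file loading; register the Llama3 stop
      // token plus the new <|eot|> / <|end_of_text|> tokens from this PR.
      eos_ids_.insert(tokenizer_->encode("<|eot_id|>"));
      llama_version_ = LlamaVersion::kLlama3;
      eos_ids_.insert(tokenizer_->encode("<|eot|>"));
      eos_ids_.insert(tokenizer_->encode("<|end_of_text|>"));
      return;
    }
    // Default path: assume tiktoken, fall back to BPE if loading fails
    // (the real code ET_CHECKs that the fallback load succeeds).
    tokenizer_ = std::make_unique<Tokenizer>(tiktoken_compatible);
    if (!tokenizer_->load(path)) {
      tokenizer_ = std::make_unique<Tokenizer>();  // Llama2cTokenizer (BPE)
      tokenizer_->load(path);
      llama_version_ = LlamaVersion::kLlama2;
    } else {
      eos_ids_.insert(tokenizer_->encode("<|eot_id|>"));
      llama_version_ = LlamaVersion::kLlama3;
    }
    eos_ids_.insert(tokenizer_->eos_tok());
  }
};

}  // namespace sketch

int main() {
  sketch::Runner injected;
  injected.tokenizer_ = std::make_unique<sketch::Tokenizer>();
  injected.load("unused", /*tiktoken_compatible=*/false);

  sketch::Runner defaulted;
  defaulted.load("tokenizer.bin", /*tiktoken_compatible=*/false);

  std::printf(
      "injected eos ids: %zu, fallback picked llama2: %d\n",
      injected.eos_ids_.size(),
      static_cast<int>(
          defaulted.llama_version_ == sketch::LlamaVersion::kLlama2));
  return 0;
}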
4 changes: 3 additions & 1 deletion examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -24,6 +24,7 @@
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/module/module.h>
#include <pytorch/tokenizers/tokenizer.h>

namespace example {

enum LlamaVersion {
@@ -41,7 +42,8 @@ class Runner {
const std::string& kv_updater = "SmartMask",
const int ngram = 0,
const int window = 0,
const int gcap = 0);
const int gcap = 0,
std::unique_ptr<tokenizers::Tokenizer> tokenizer = nullptr);

bool is_loaded() const;
executorch::runtime::Error load();
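Because the new parameter is trailing and defaults to nullptr, every pre-existing call site keeps compiling unchanged; only callers that want to inject a tokenizer pass the extra argument. A runnable mock of that compatibility property (MockRunner and Tok are hypothetical stand-ins, not the real API):

#include <memory>
#include <utility>

struct Tok {};  // stand-in for tokenizers::Tokenizer

struct MockRunner {
  // Taking the unique_ptr by value makes the constructor a "sink": the
  // caller std::moves ownership in, or omits the argument entirely.
  explicit MockRunner(
      int gcap = 0,
      std::unique_ptr<Tok> tokenizer = nullptr)  // new trailing parameter
      : tokenizer_(std::move(tokenizer)), gcap_(gcap) {}

  std::unique_ptr<Tok> tokenizer_;
  int gcap_;
};

int main() {
  MockRunner before_patch(4);  // old-style call site: still compiles
  MockRunner after_patch(4, std::make_unique<Tok>());  // injects a tokenizer
  return (after_patch.tokenizer_ && !before_patch.tokenizer_) ? 0 : 1;
}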
2 changes: 2 additions & 0 deletions examples/qualcomm/oss_scripts/llama/targets.bzl
@@ -33,6 +33,8 @@ def define_common_targets():
"//executorch/extension/evalue_util:print_evalue",
"//executorch/backends/qualcomm/runtime:runtime",
"//pytorch/tokenizers:llama2c_tokenizer",
"//pytorch/tokenizers:regex_lookahead",
"//pytorch/tokenizers:tiktoken",
],
external_deps = [
"gflags",