1111
1212#include < executorch/examples/models/llama/runner/runner.h>
1313
14- #include < algorithm>
15- #include < ctime>
16-
1714#include < executorch/extension/llm/runner/util.h>
1815
1916#include < executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -35,129 +32,161 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
3532static constexpr auto kMaxContextLen = " get_max_context_len" ;
3633static constexpr auto kVocabSize = " get_vocab_size" ;
3734static constexpr auto kUseKVCache = " use_kv_cache" ;
38- static constexpr auto kUseSDPAWithKVCache = " use_sdpa_with_kv_cache" ;
3935} // namespace
4036
41- Runner:: Runner (
37+ std::unique_ptr< Runner> Runner::create (
4238 const std::string& model_path,
4339 const std::string& tokenizer_path,
44- std::optional<const std::string> data_path)
45- // NOTE: we observed ~2x loading performance increase on iPhone 15
46- // and a ~5% improvement on Galaxy S22 by switching to
47- // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
48- : tokenizer_path_(tokenizer_path),
49- metadata_ ({
50- {kEnableDynamicShape , false },
51- {kMaxSeqLen , 128 },
52- {kMaxContextLen , 128 },
53- {kUseKVCache , true },
54- {kUseSDPAWithKVCache , false },
55- }) {
56- if (data_path.has_value ()) {
57- module_ = std::make_unique<Module>(
58- model_path, data_path.value (), Module::LoadMode::File);
59- } else {
60- module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
61- }
40+ std::optional<const std::string> data_path,
41+ float temperature) {
6242 ET_LOG (
6343 Info,
6444 " Creating LLaMa runner: model_path=%s, tokenizer_path=%s" ,
6545 model_path.c_str (),
6646 tokenizer_path.c_str ());
67- }
6847
69- [[deprecated(
70- " This constructor is deprecated. Use the constructor without temperature parameter instead." )]]
71- Runner::Runner (
72- const std::string& model_path,
73- const std::string& tokenizer_path,
74- const float temperature,
75- std::optional<const std::string> data_path)
76- : Runner(model_path, tokenizer_path, std::move(data_path)) {
77- temperature_ = temperature;
78- }
48+ // Create the Module
49+ std::unique_ptr<Module> module ;
50+ if (data_path.has_value ()) {
51+ module = std::make_unique<Module>(
52+ model_path, data_path.value (), Module::LoadMode::File);
53+ } else {
54+ module = std::make_unique<Module>(model_path, Module::LoadMode::File);
55+ }
7956
80- bool Runner::is_loaded () const {
81- return module_->is_loaded () && tokenizer_ && text_decoder_runner_ &&
82- text_prefiller_ && text_token_generator_;
83- }
57+ // Initialize metadata with default values
58+ std::unordered_map<std::string, int64_t > metadata ({
59+ {kEnableDynamicShape , false },
60+ {kMaxSeqLen , 128 },
61+ {kMaxContextLen , 128 },
62+ {kUseKVCache , true },
63+ });
8464
85- Error Runner::load () {
86- if (is_loaded ()) {
87- return Error::Ok;
88- }
89- ET_CHECK_OK_OR_RETURN_ERROR (module_->load_method (" forward" ));
90- // load tokenizer. Assuming tiktoken is the default tokenizer
91- tokenizer_ = nullptr ;
92- tokenizer_ = get_tiktoken_for_llama ();
93- ::tokenizers::Error err = tokenizer_->load (tokenizer_path_);
94- // Rely on tiktoken to throw error if the artifact is incompatible. Then we
95- // fallback to BPE tokenizer.
96- if (err != ::tokenizers::Error::Ok) {
65+ // Create and load tokenizer
66+ std::unique_ptr<::tokenizers::Tokenizer> tokenizer = get_tiktoken_for_llama ();
67+ ::tokenizers::Error tk_err = tokenizer->load (tokenizer_path);
68+
69+ // Fallback to BPE tokenizer if tiktoken fails
70+ if (tk_err != ::tokenizers::Error::Ok) {
9771 ET_LOG (
9872 Info,
9973 " Failed to load %s as a Tiktoken artifact, trying BPE tokenizer" ,
100- tokenizer_path_.c_str ());
101- tokenizer_.reset ();
102- tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
103- err = tokenizer_->load (tokenizer_path_);
104- ET_CHECK_TK_OK_OR_RETURN_ERROR (
105- err,
106- " Failed to load %s as a llama2.c tokenizer artifact" ,
107- tokenizer_path_.c_str ());
74+ tokenizer_path.c_str ());
75+ tokenizer.reset ();
76+ tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
77+ tk_err = tokenizer->load (tokenizer_path);
78+ if (tk_err != ::tokenizers::Error::Ok) {
79+ ET_LOG (
80+ Error,
81+ " Failed to load %s as a llama2.c tokenizer artifact" ,
82+ tokenizer_path.c_str ());
83+ return nullptr ;
84+ }
10885 }
10986
11087 ET_LOG (Info, " Reading metadata from model" );
11188
112- metadata_[kBosId ] = tokenizer_->bos_tok ();
89+ // Set tokenizer-related metadata
90+ metadata[kBosId ] = tokenizer->bos_tok ();
11391 auto eos_ids = std::make_unique<std::unordered_set<uint64_t >>(
114- std::unordered_set<uint64_t >{tokenizer_->eos_tok ()});
115- metadata_[kVocabSize ] = tokenizer_->vocab_size ();
116-
117- const auto method_names =
118- ET_UNWRAP (module_->method_names (), " Failed reading method names" );
92+ std::unordered_set<uint64_t >{tokenizer->eos_tok ()});
93+ metadata[kVocabSize ] = tokenizer->vocab_size ();
94+
95+ // Read metadata from the model
96+ auto method_names_result = module ->method_names ();
97+ if (method_names_result.error () != Error::Ok) {
98+ ET_LOG (Error, " Failed reading method names" );
99+ return nullptr ;
100+ }
101+ const auto method_names = method_names_result.get ();
119102
120- for (auto & pair : metadata_ ) {
103+ for (auto & pair : metadata ) {
121104 const auto & method_name = pair.first ;
122105 auto & value = pair.second ;
123106
124107 if (method_names.count (method_name)) {
125- value = ET_UNWRAP (module_->get (method_name))
126- .toScalar ()
127- .to <decltype (metadata_)::mapped_type>();
108+ auto get_result = module ->get (method_name);
109+ value = get_result.get ().toScalar ().to <decltype (metadata)::mapped_type>();
128110 } else {
129111 ET_LOG (
130112 Info,
131- " Methond %s not found, using the default value %" PRId64,
113+ " Method %s not found, using the default value %" PRId64,
132114 method_name.c_str (),
133115 value);
134116 }
135117 ET_LOG (Info, " Metadata: %s = %" PRId64, method_name.c_str (), value);
136118 }
119+
120+ // Get EOS IDs if available
137121 if (method_names.count (kEosIds )) {
138122 eos_ids->clear ();
139- for (const auto & eos_id : ET_UNWRAP (module_->execute (kEosIds ))) {
123+ auto execute_result = module ->execute (kEosIds );
124+ if (execute_result.error () != Error::Ok) {
125+ ET_LOG (Error, " Failed to execute %s" , kEosIds );
126+ return nullptr ;
127+ }
128+ for (const auto & eos_id : execute_result.get ()) {
140129 auto value = eos_id.toScalar ().to <int64_t >();
141130 eos_ids->emplace (value);
142131 ET_LOG (Info, " eos_id = %" PRId64, value);
143132 }
144133 }
145- // @lint-ignore CLANGTIDY facebook-hte-Deprecated
146- text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
147- module_.get (), metadata_.at (kUseKVCache ));
148- text_prefiller_ = std::make_unique<llm::TextPrefiller>(
149- text_decoder_runner_.get (),
150- metadata_.at (kUseKVCache ),
151- metadata_.at (kEnableDynamicShape ),
152- metadata_.at (kMaxSeqLen ));
153-
154- text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
155- tokenizer_.get (),
156- text_decoder_runner_.get (),
157- metadata_.at (kUseKVCache ),
134+
135+ // Create text_decoder_runner
136+ auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
137+ module .get (), metadata.at (kUseKVCache ));
138+
139+ // Create text_prefiller
140+ auto text_prefiller = std::make_unique<llm::TextPrefiller>(
141+ text_decoder_runner.get (),
142+ metadata.at (kUseKVCache ),
143+ metadata.at (kEnableDynamicShape ),
144+ metadata.at (kMaxSeqLen ));
145+
146+ // Create text_token_generator with stats
147+ auto stats = new llm::Stats ();
148+ auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
149+ tokenizer.get (),
150+ text_decoder_runner.get (),
151+ metadata.at (kUseKVCache ),
158152 std::move (eos_ids),
159- &stats_);
153+ stats);
154+
155+ // Create and return the Runner instance
156+ return std::make_unique<Runner>(
157+ std::move (metadata),
158+ std::move (tokenizer),
159+ std::move (text_prefiller),
160+ std::move (text_token_generator),
161+ temperature);
162+ }
160163
164+ Runner::Runner (
165+ std::unordered_map<std::string, int64_t > metadata,
166+ std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
167+ std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
168+ std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
169+ text_token_generator,
170+ float temperature)
171+ : tokenizer_(std::move(tokenizer)),
172+ metadata_ (std::move(metadata)),
173+ text_prefiller_(std::move(text_prefiller)),
174+ text_token_generator_(std::move(text_token_generator)),
175+ temperature_(temperature) {
176+ // Note: This constructor assumes that text_prefiller and text_token_generator
177+ // already have references to the Module and TextDecoderRunner they need
178+ }
179+
180+ bool Runner::is_loaded () const {
181+ return text_prefiller_->is_loaded () && text_token_generator_->is_loaded ();
182+ }
183+
184+ Error Runner::load () {
185+ if (is_loaded ()) {
186+ return Error::Ok;
187+ }
188+ ET_CHECK_OK_OR_RETURN_ERROR (text_prefiller_->load ());
189+ ET_CHECK_OK_OR_RETURN_ERROR (text_token_generator_->load ());
161190 return Error::Ok;
162191}
163192
0 commit comments