  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
  */
 
 // A simple llama2 runner that includes preprocessing and post processing logic.
 // The module takes in a string as input and emits a string as output.
 
 #include <executorch/examples/models/llama/runner/runner.h>
 
-#include <algorithm>
-#include <ctime>
-
 #include <executorch/extension/llm/runner/util.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -62,125 +60,155 @@ std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
 }
 } // namespace
 
-Runner::Runner(
+std::unique_ptr<Runner> Runner::create(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    std::optional<const std::string> data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : tokenizer_path_(tokenizer_path),
-      metadata_({
-          {kEnableDynamicShape, false},
-          {kMaxSeqLen, 128},
-          {kMaxContextLen, 128},
-          {kUseKVCache, true},
-          {kUseSDPAWithKVCache, false},
-      }) {
-  if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
-  } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
+    std::optional<const std::string> data_path,
+    float temperature) {
   ET_LOG(
       Info,
       "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
       model_path.c_str(),
       tokenizer_path.c_str());
-}
 
-[[deprecated(
-    "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
-Runner::Runner(
-    const std::string& model_path,
-    const std::string& tokenizer_path,
-    const float temperature,
-    std::optional<const std::string> data_path)
-    : Runner(model_path, tokenizer_path, std::move(data_path)) {
-  temperature_ = temperature;
-}
-
-bool Runner::is_loaded() const {
-  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
-      text_prefiller_ && text_token_generator_;
-}
-
-Error Runner::load() {
-  if (is_loaded()) {
-    return Error::Ok;
+  // Create the Module
+  std::unique_ptr<Module> module;
+  if (data_path.has_value()) {
+    module = std::make_unique<Module>(
+        model_path, data_path.value(), Module::LoadMode::File);
+  } else {
+    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
   }
-  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
 
-  // Load tokenizer.
-  tokenizer_ = load_tokenizer(tokenizer_path_);
-  if (tokenizer_ == nullptr) {
+  // Initialize metadata with default values
+  std::unordered_map<std::string, int64_t> metadata({
+      {kEnableDynamicShape, false},
+      {kMaxSeqLen, 128},
+      {kMaxContextLen, 128},
+      {kUseKVCache, true},
+      {kUseSDPAWithKVCache, false},
+  });
+
+  // Create and load tokenizer
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
+      load_tokenizer(tokenizer_path);
+
+  // Bail out if the tokenizer could not be loaded as any supported type
+  if (tokenizer == nullptr) {
     ET_LOG(
         Info,
-        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-    auto err = tokenizer_->load(tokenizer_path_);
-    ET_CHECK_TK_OK_OR_RETURN_ERROR(
-        err,
-        "Failed to load %s as a llama2.c tokenizer artifact",
-        tokenizer_path_.c_str());
-    return ::executorch::runtime::Error::InvalidArgument;
+        "Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c tokenizer, make sure the artifact is one of these types",
+        tokenizer_path.c_str());
+    return nullptr;
   }
 
   ET_LOG(Info, "Reading metadata from model");
 
-  metadata_[kBosId] = tokenizer_->bos_tok();
+  // Set tokenizer-related metadata
+  metadata[kBosId] = tokenizer->bos_tok();
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
-      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
-  metadata_[kVocabSize] = tokenizer_->vocab_size();
-
-  const auto method_names =
-      ET_UNWRAP(module_->method_names(), "Failed reading method names");
+      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
+  metadata[kVocabSize] = tokenizer->vocab_size();
+
+  // Read metadata from the model
+  auto method_names_result = module->method_names();
+  if (method_names_result.error() != Error::Ok) {
+    ET_LOG(Error, "Failed reading method names");
+    return nullptr;
+  }
+  const auto method_names = method_names_result.get();
 
-  for (auto& pair : metadata_) {
+  for (auto& pair : metadata) {
     const auto& method_name = pair.first;
     auto& value = pair.second;
 
     if (method_names.count(method_name)) {
-      value = ET_UNWRAP(module_->get(method_name))
-                  .toScalar()
-                  .to<decltype(metadata_)::mapped_type>();
+      auto get_result = module->get(method_name);
+      value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
     } else {
       ET_LOG(
           Info,
-          "Methond %s not found, using the default value %" PRId64,
+          "Method %s not found, using the default value %" PRId64,
           method_name.c_str(),
           value);
     }
     ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
   }
+
+  // Get EOS IDs if available
   if (method_names.count(kEosIds)) {
     eos_ids->clear();
-    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
+    auto execute_result = module->execute(kEosIds);
+    if (execute_result.error() != Error::Ok) {
+      ET_LOG(Error, "Failed to execute %s", kEosIds);
+      return nullptr;
+    }
+    for (const auto& eos_id : execute_result.get()) {
       auto value = eos_id.toScalar().to<int64_t>();
       eos_ids->emplace(value);
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
-  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape),
-      metadata_.at(kMaxSeqLen));
-
-  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
+
+  // Create text_decoder_runner; TextPrefiller and TextTokenGenerator below
+  // hold raw pointers to it
+  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
+      module.get(), metadata.at(kUseKVCache));
+
+  // Create text_prefiller
+  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
+      metadata.at(kEnableDynamicShape),
+      metadata.at(kMaxSeqLen));
+
+  // Create text_token_generator with stats
+  auto stats = std::make_unique<llm::Stats>();
+  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
+      tokenizer.get(),
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats.get());
+
+  // Create and return the Runner instance
+  return std::make_unique<Runner>(
+      std::move(metadata),
+      std::move(tokenizer),
+      std::move(text_prefiller),
+      std::move(text_token_generator),
+      std::move(stats),
+      temperature);
+}
+
+Runner::Runner(
+    std::unordered_map<std::string, int64_t> metadata,
+    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+    std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
+    std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+        text_token_generator,
+    std::unique_ptr<::executorch::extension::llm::Stats> stats,
+    float temperature)
+    : tokenizer_(std::move(tokenizer)),
+      metadata_(std::move(metadata)),
+      text_prefiller_(std::move(text_prefiller)),
+      text_token_generator_(std::move(text_token_generator)),
+      stats_(std::move(stats)),
+      temperature_(temperature) {
+  // Note: This constructor assumes that text_prefiller and text_token_generator
+  // already have references to the Module and TextDecoderRunner they need
+}
+
+bool Runner::is_loaded() const {
+  return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
+}
 
+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
+  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
   return Error::Ok;
 }
 
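Usage note: call sites that previously constructed a Runner directly now go
through this factory, which reports failure by returning nullptr rather than
an Error. A minimal sketch of a hypothetical call site (the example::Runner
qualification and the variable names are assumptions, not part of this commit):

    // Construct everything up front; create() returns nullptr on any failure.
    // Passing -1.0f as temperature defers to the per-generation config value
    // (see the sentinel check in generate() below).
    std::unique_ptr<example::Runner> runner = example::Runner::create(
        model_path, tokenizer_path, /*data_path=*/std::nullopt,
        /*temperature=*/-1.0f);
    if (runner == nullptr) {
      ET_LOG(Error, "Failed to create llama runner");
      return 1;
    }
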
@@ -201,9 +229,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }
 
   if (config.warming) {
@@ -229,7 +257,7 @@ Error Runner::generate(
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.
 
-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;
 
   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -270,8 +298,8 @@ Error Runner::generate(
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();
 
   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -292,7 +320,7 @@ Error Runner::generate(
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback));
 
-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
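The temperature handling above keeps the deprecated per-runner value as a
sentinel: a runner created with -1.0f has no override, so the per-call
config.temperature wins. A standalone sketch of that selection logic (the
helper name is illustrative, not from this commit):

    // Pick the effective sampling temperature: -1.0f means "no runner-level
    // override", so fall back to the per-generation config value.
    inline float effective_temperature(float runner_temp, float config_temp) {
      return runner_temp == -1.0f ? config_temp : runner_temp;
    }
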
@@ -305,17 +333,17 @@ Error Runner::generate(
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }
 
-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;
 
   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }
 
   return Error::Ok;
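
The switch away from ET_UNWRAP in create() follows from its signature: that
macro propagates failures by returning an Error, which does not type-check in
a function returning std::unique_ptr<Runner>, so each Result<T> is checked
explicitly instead. A minimal sketch of the pattern, assuming a
Result-returning API such as Module::method_names():

    // Result<T> couples an Error with a value; inspect error() before get().
    auto method_names_result = module->method_names();
    if (method_names_result.error() != Error::Ok) {
      ET_LOG(Error, "Failed reading method names");
      return nullptr;  // surface failure as a null Runner, not an Error
    }
    const auto method_names = method_names_result.get();  // safe after check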