 
 #include <executorch/examples/models/llama/runner/runner.h>
 
-#include <algorithm>
-#include <ctime>
-
 #include <executorch/extension/llm/runner/util.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -35,130 +32,165 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kMaxContextLen = "get_max_context_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
-static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 } // namespace
 
-Runner::Runner(
+std::unique_ptr<Runner> Runner::create(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    std::optional<const std::string> data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : tokenizer_path_(tokenizer_path),
-      metadata_({
-          {kEnableDynamicShape, false},
-          {kMaxSeqLen, 128},
-          {kMaxContextLen, 128},
-          {kUseKVCache, true},
-          {kUseSDPAWithKVCache, false},
-      }) {
-  if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
-  } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
+    std::optional<const std::string> data_path,
+    float temperature) {
   ET_LOG(
       Info,
       "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
       model_path.c_str(),
       tokenizer_path.c_str());
-}
 
-[[deprecated(
-    "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
-Runner::Runner(
-    const std::string& model_path,
-    const std::string& tokenizer_path,
-    const float temperature,
-    std::optional<const std::string> data_path)
-    : Runner(model_path, tokenizer_path, std::move(data_path)) {
-  temperature_ = temperature;
-}
+  // Create the Module
+  std::unique_ptr<Module> module;
+  if (data_path.has_value()) {
+    module = std::make_unique<Module>(
+        model_path, data_path.value(), Module::LoadMode::File);
+  } else {
+    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
+  }
 
-bool Runner::is_loaded() const {
-  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
-      text_prefiller_ && text_token_generator_;
-}
+  // Initialize metadata with default values
+  std::unordered_map<std::string, int64_t> metadata({
+      {kEnableDynamicShape, false},
+      {kMaxSeqLen, 128},
+      {kMaxContextLen, 128},
+      {kUseKVCache, true},
+  });
 
-Error Runner::load() {
-  if (is_loaded()) {
-    return Error::Ok;
-  }
-  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
-  // load tokenizer. Assuming tiktoken is the default tokenizer
-  tokenizer_ = nullptr;
-  tokenizer_ = get_tiktoken_for_llama();
-  ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-  // fallback to BPE tokenizer.
-  if (err != ::tokenizers::Error::Ok) {
+  // Create and load the tokenizer
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = get_tiktoken_for_llama();
+  ::tokenizers::Error tk_err = tokenizer->load(tokenizer_path);
+
+  // Fall back to the BPE tokenizer if tiktoken fails
+  if (tk_err != ::tokenizers::Error::Ok) {
     ET_LOG(
         Info,
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-    err = tokenizer_->load(tokenizer_path_);
-    ET_CHECK_TK_OK_OR_RETURN_ERROR(
-        err,
-        "Failed to load %s as a llama2.c tokenizer artifact",
-        tokenizer_path_.c_str());
+        tokenizer_path.c_str());
+    tokenizer.reset();
+    tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    tk_err = tokenizer->load(tokenizer_path);
+    if (tk_err != ::tokenizers::Error::Ok) {
+      ET_LOG(
+          Error,
+          "Failed to load %s as a llama2.c tokenizer artifact",
+          tokenizer_path.c_str());
+      return nullptr;
+    }
   }
 
   ET_LOG(Info, "Reading metadata from model");
 
-  metadata_[kBosId] = tokenizer_->bos_tok();
+  // Set tokenizer-related metadata
+  metadata[kBosId] = tokenizer->bos_tok();
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
-      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
-  metadata_[kVocabSize] = tokenizer_->vocab_size();
-
-  const auto method_names =
-      ET_UNWRAP(module_->method_names(), "Failed reading method names");
+      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
+  metadata[kVocabSize] = tokenizer->vocab_size();
+
+  // Read metadata from the model
+  auto method_names_result = module->method_names();
+  if (method_names_result.error() != Error::Ok) {
+    ET_LOG(Error, "Failed reading method names");
+    return nullptr;
+  }
+  const auto method_names = method_names_result.get();
 
-  for (auto& pair : metadata_) {
+  for (auto& pair : metadata) {
     const auto& method_name = pair.first;
     auto& value = pair.second;
 
     if (method_names.count(method_name)) {
-      value = ET_UNWRAP(module_->get(method_name))
-                  .toScalar()
-                  .to<decltype(metadata_)::mapped_type>();
+      auto get_result = module->get(method_name);
+      value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
     } else {
       ET_LOG(
           Info,
-          "Methond %s not found, using the default value %" PRId64,
+          "Method %s not found, using the default value %" PRId64,
           method_name.c_str(),
           value);
     }
     ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
   }
+
+  // Get EOS IDs if available
   if (method_names.count(kEosIds)) {
     eos_ids->clear();
-    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
+    auto execute_result = module->execute(kEosIds);
+    if (execute_result.error() != Error::Ok) {
+      ET_LOG(Error, "Failed to execute %s", kEosIds);
+      return nullptr;
+    }
+    for (const auto& eos_id : execute_result.get()) {
       auto value = eos_id.toScalar().to<int64_t>();
       eos_ids->emplace(value);
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
-  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape),
-      metadata_.at(kMaxSeqLen));
-
-  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
+
+  // Create text_decoder_runner. TextPrefiller and TextTokenGenerator hold raw
+  // pointers to it, so ownership moves into the Runner below.
+  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
+      module.get(), metadata.at(kUseKVCache));
+
+  // Create text_prefiller
+  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
+      metadata.at(kEnableDynamicShape),
+      metadata.at(kMaxSeqLen));
+
+  // Create text_token_generator with stats
+  auto stats = std::make_unique<llm::Stats>();
+  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
+      tokenizer.get(),
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats.get());
+
+  // Create and return the Runner instance. Ownership of the Module and
+  // TextDecoderRunner moves into the Runner so that the raw pointers held by
+  // text_prefiller and text_token_generator remain valid.
+  return std::make_unique<Runner>(
+      std::move(metadata),
+      std::move(tokenizer),
+      std::move(module),
+      std::move(text_decoder_runner),
+      std::move(text_prefiller),
+      std::move(text_token_generator),
+      std::move(stats),
+      temperature);
+}
 
+Runner::Runner(
+    std::unordered_map<std::string, int64_t> metadata,
+    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+    std::unique_ptr<::executorch::extension::Module> module,
+    std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
+        text_decoder_runner,
+    std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
+    std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+        text_token_generator,
+    std::unique_ptr<::executorch::extension::llm::Stats> stats,
+    float temperature)
+    : tokenizer_(std::move(tokenizer)),
+      metadata_(std::move(metadata)),
+      module_(std::move(module)),
+      text_decoder_runner_(std::move(text_decoder_runner)),
+      text_prefiller_(std::move(text_prefiller)),
+      text_token_generator_(std::move(text_token_generator)),
+      stats_(std::move(stats)),
+      temperature_(temperature) {
+  // The Runner keeps the Module and TextDecoderRunner alive because
+  // text_prefiller_ and text_token_generator_ hold raw pointers to them.
+}
+
+bool Runner::is_loaded() const {
+  return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
+}
+
+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
+  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
   return Error::Ok;
 }
 
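With this hunk, `Runner` construction becomes a fallible factory rather than a throwing constructor. A minimal sketch of what a call site might look like after this change (the file paths, the `example::` namespace, and the `main` wiring are illustrative assumptions, not part of the commit):

```cpp
#include <executorch/examples/models/llama/runner/runner.h>

#include <optional>

int main() {
  // create() returns nullptr on failure (e.g. an unreadable tokenizer),
  // so callers check the pointer instead of handling an error code.
  // temperature = -1.0f defers the sampling temperature to the
  // GenerationConfig passed to generate(), per the hunks below.
  auto runner = example::Runner::create(
      "/path/to/llama.pte",        // hypothetical model path
      "/path/to/tokenizer.model",  // hypothetical tokenizer path
      /*data_path=*/std::nullopt,
      /*temperature=*/-1.0f);
  if (!runner) {
    return 1;
  }
  // Loading remains a separate, Error-returning step.
  if (runner->load() != executorch::runtime::Error::Ok) {
    return 1;
  }
  return 0;
}
```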
@@ -179,9 +211,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }
 
   if (config.warming) {
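The remaining `generate()` hunks are mechanical fallout of the same refactor: `stats_` changes from a `llm::Stats` value member to a `std::unique_ptr<llm::Stats>` created in `create()` (so that `TextTokenGenerator` can hold a raw pointer to it), so member accesses switch from `.` to `->`, and the reporting and callback sites dereference the pointer.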
@@ -207,7 +239,7 @@ Error Runner::generate(
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.
 
-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;
 
   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -248,8 +280,8 @@ Error Runner::generate(
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();
 
   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -270,7 +302,7 @@ Error Runner::generate(
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback));
 
-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
@@ -283,17 +315,17 @@ Error Runner::generate(
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }
 
-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;
 
   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }
 
   return Error::Ok;
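Because `create()` signals failure by returning `nullptr` rather than propagating an `Error`, the `ET_UNWRAP` calls from the old `load()` are replaced with explicit checks on `Result<T>`, as seen throughout the first hunk. A minimal sketch of that pattern, with `load_something()` as a hypothetical stand-in for calls like `module->method_names()`:

```cpp
#include <executorch/runtime/core/result.h>

#include <memory>

using executorch::runtime::Error;
using executorch::runtime::Result;

// Hypothetical fallible call standing in for module->method_names().
Result<int64_t> load_something();

std::unique_ptr<int64_t> create_sketch() {
  Result<int64_t> res = load_something();
  if (res.error() != Error::Ok) {
    // Factory-style code returns nullptr instead of propagating the Error.
    return nullptr;
  }
  // Unwrap with get() only after the error check.
  return std::make_unique<int64_t>(res.get());
}
```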