 // The module takes in a string as input and emits a string as output.
 
 #include <executorch/examples/models/llama/runner/runner.h>
-
-#include <executorch/extension/llm/runner/util.h>
+#include <executorch/extension/module/module.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <pytorch/tokenizers/hf_tokenizer.h>
@@ -26,41 +25,14 @@ using ::executorch::runtime::Result;
 
 namespace llm = ::executorch::extension::llm;
 
-namespace {
-static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
-static constexpr auto kBosId = "get_bos_id";
-static constexpr auto kEosIds = "get_eos_ids";
-static constexpr auto kMaxSeqLen = "get_max_seq_len";
-static constexpr auto kMaxContextLen = "get_max_context_len";
-static constexpr auto kVocabSize = "get_vocab_size";
-static constexpr auto kUseKVCache = "use_kv_cache";
-static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
-
-std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
-    const std::string& tokenizer_path) {
-  auto json_tokenizer = std::make_unique<tokenizers::HFTokenizer>();
-  if (json_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
-    ET_LOG(Info, "Loaded json tokenizer");
-    return json_tokenizer;
-  }
-
-  auto tiktoken_tokenizer = get_tiktoken_for_llama();
-  if (tiktoken_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
-    ET_LOG(Info, "Loaded TikToken tokenizer");
-    return tiktoken_tokenizer;
-  }
-
-  auto bpe_tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
-  if (bpe_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
-    ET_LOG(Info, "Loaded BPE tokenizer");
-    return bpe_tokenizer;
-  }
-
-  return nullptr;
+std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
+    const std::string& tokenizer_path,
+    Version version) {
+  auto special_tokens = get_special_tokens(version);
+  return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
 }
-} // namespace
 
-std::unique_ptr<Runner> Runner::create(
+std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
     const std::string& model_path,
     const std::string& tokenizer_path,
     std::optional<const std::string> data_path,
@@ -71,309 +43,19 @@ std::unique_ptr<Runner> Runner::create(
       model_path.c_str(),
       tokenizer_path.c_str());
 
-  // Create the Module
-  std::unique_ptr<Module> module;
-  if (data_path.has_value()) {
-    module = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
-  } else {
-    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
-
-  // Initialize metadata with default values
-  std::unordered_map<std::string, int64_t> metadata({
-      {kEnableDynamicShape, false},
-      {kMaxSeqLen, 128},
-      {kMaxContextLen, 128},
-      {kUseKVCache, true},
-      {kUseSDPAWithKVCache, false},
-  });
-
   // Create and load tokenizer
   std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
-      load_tokenizer(tokenizer_path);
+      load_llama_tokenizer(tokenizer_path, Version::Default);
 
-  // Fallback to BPE tokenizer if tiktoken fails
   if (tokenizer == nullptr) {
     ET_LOG(
         Info,
         "Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c tokenizer, make sure the artifact is one of these types",
         tokenizer_path.c_str());
     return nullptr;
   }
-
-  ET_LOG(Info, "Reading metadata from model");
-
-  // Set tokenizer-related metadata
-  metadata[kBosId] = tokenizer->bos_tok();
-  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
-      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
-  metadata[kVocabSize] = tokenizer->vocab_size();
-
-  // Read metadata from the model
-  auto method_names_result = module->method_names();
-  if (method_names_result.error() != Error::Ok) {
-    ET_LOG(Error, "Failed reading method names");
-    return nullptr;
-  }
-  const auto method_names = method_names_result.get();
-
-  for (auto& pair : metadata) {
-    const auto& method_name = pair.first;
-    auto& value = pair.second;
-
-    if (method_names.count(method_name)) {
-      auto get_result = module->get(method_name);
-      value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
-    } else {
-      ET_LOG(
-          Info,
-          "Method %s not found, using the default value %" PRId64,
-          method_name.c_str(),
-          value);
-    }
-    ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
-  }
-
-  // Get EOS IDs if available
-  if (method_names.count(kEosIds)) {
-    eos_ids->clear();
-    auto execute_result = module->execute(kEosIds);
-    if (execute_result.error() != Error::Ok) {
-      ET_LOG(Error, "Failed to execute %s", kEosIds);
-      return nullptr;
-    }
-    for (const auto& eos_id : execute_result.get()) {
-      auto value = eos_id.toScalar().to<int64_t>();
-      eos_ids->emplace(value);
-      ET_LOG(Info, "eos_id = %" PRId64, value);
-    }
-  }
-
-  // Create text_decoder_runner. Use a shared_ptr so that it can be shared with
-  // TextPrefiller and TextTokenGenerator
-  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
-      module.get(), metadata.at(kUseKVCache));
-
-  // Create text_prefiller
-  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner.get(),
-      metadata.at(kUseKVCache),
-      metadata.at(kEnableDynamicShape),
-      metadata.at(kMaxSeqLen));
-
-  // Create text_token_generator with stats
-  auto stats = std::make_unique<llm::Stats>();
-  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer.get(),
-      text_decoder_runner.get(),
-      metadata.at(kUseKVCache),
-      std::move(eos_ids),
-      stats.get());
-
-  // Create and return the Runner instance
-  return std::make_unique<Runner>(
-      std::move(metadata),
-      std::move(tokenizer),
-      std::move(module),
-      std::move(text_decoder_runner),
-      std::move(text_prefiller),
-      std::move(text_token_generator),
-      std::move(stats),
-      temperature);
-}
-
-Runner::Runner(
-    std::unordered_map<std::string, int64_t> metadata,
-    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
-    std::unique_ptr<::executorch::extension::Module> module,
-    std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
-        text_decoder_runner,
-    std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
-    std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
-        text_token_generator,
-    std::unique_ptr<::executorch::extension::llm::Stats> stats,
-    float temperature)
-    : tokenizer_(std::move(tokenizer)),
-      metadata_(std::move(metadata)),
-      module_(std::move(module)),
-      text_decoder_runner_(std::move(text_decoder_runner)),
-      text_prefiller_(std::move(text_prefiller)),
-      text_token_generator_(std::move(text_token_generator)),
-      stats_(std::move(stats)),
-      temperature_(temperature) {
-  // Note: This constructor assumes that text_prefiller and text_token_generator
-  // already have references to the Module and TextDecoderRunner they need
-}
-
-bool Runner::is_loaded() const {
-  return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
-}
-
-Error Runner::load() {
-  if (is_loaded()) {
-    return Error::Ok;
-  }
-  ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
-  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
-  return Error::Ok;
-}
-
-// Don't print with the same priority during warmup
-#define RUNNER_ET_LOG(warmup, format, ...) \
-  if (warmup) {                            \
-    ET_LOG(Debug, format, __VA_ARGS__);    \
-  } else {                                 \
-    ET_LOG(Info, format, __VA_ARGS__);     \
-  }
-
-Error Runner::generate(
-    const std::string& prompt,
-    const ::executorch::extension::llm::GenerationConfig& config,
-    std::function<void(const std::string&)> token_callback,
-    std::function<void(const llm::Stats&)> stats_callback) {
-  // Prepare the inputs.
-  // Use ones-initialized inputs.
-  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
-  if (!is_loaded()) {
-    stats_->model_load_start_ms = llm::time_in_ms();
-    ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_->model_load_end_ms = llm::time_in_ms();
-  }
-
-  if (config.warming) {
-    ET_LOG(Info, "Doing a warmup run...");
-  }
-
-  RUNNER_ET_LOG(
-      config.warming,
-      "RSS after loading model: %f MiB (0 if unsupported)",
-      llm::get_rss_bytes() / 1024.0 / 1024.0);
-
-  // Wrap the token_callback with print function
-  std::function<void(const std::string&)> wrapped_callback =
-      [token_callback, config](const std::string& piece) {
-        if (!config.warming) {
-          llm::safe_printf(piece.c_str());
-          fflush(stdout);
-        }
-        if (token_callback) {
-          token_callback(piece);
-        }
-      };
-  // First token time only measures the time it takes to encode the prompt and
-  // return a response token.
-
-  stats_->inference_start_ms = llm::time_in_ms();
-  shouldStop_ = false;
-
-  ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
-      prompt,
-      /* bos */ 0,
-      /* eos */ 0);
-
-  ET_CHECK_TK_OK_OR_RETURN_ERROR(
-      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
-
-  // encode the (string) prompt into tokens sequence
-  std::vector<uint64_t> prompt_tokens = encode_res.get();
-  int num_prompt_tokens = prompt_tokens.size();
-
-  ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
-  ET_CHECK_MSG(
-      num_prompt_tokens < metadata_.at(kMaxContextLen),
-      "num_prompt_tokens %d >= max_seq_len_ %" PRId64
-      ", Max seq length exceeded - please increase max seq len value in your export script",
-      num_prompt_tokens,
-      metadata_.at(kMaxContextLen));
-
-  // Determine max_new_tokens using the GenerationConfig's resolve method
-  int max_new_tokens = config.resolve_max_new_tokens(
-      metadata_.at(kMaxContextLen), num_prompt_tokens);
-
-  ET_LOG(Info, "Max new tokens resolved: %d", max_new_tokens);
-
-  // Prefill first
-  // Here feed all tokens to the model and get the next predicted token
-  // after the prompt. After that we will enter generate loop.
-
-  // print prompts
-  if (config.echo) {
-    wrapped_callback(prompt);
-  }
-  int64_t pos = 0;
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
-  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
-  uint64_t cur_token = prefill_res.get();
-  stats_->first_token_ms = llm::time_in_ms();
-  stats_->prompt_eval_end_ms = llm::time_in_ms();
-
-  // print the first token from prefill. No prev_token so use cur_token for it.
-  wrapped_callback(
-      ET_UNWRAP_TOKENIZER(tokenizer_->decode(cur_token, cur_token)));
-  RUNNER_ET_LOG(
-      config.warming,
-      "RSS after prompt prefill: %f MiB (0 if unsupported)",
-      llm::get_rss_bytes() / 1024.0 / 1024.0);
-
-  // start the main loop
-  prompt_tokens.push_back(cur_token);
-
-  // Generate max_new_tokens - 1 because prefill already generated 1 token.
-  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      prompt_tokens,
-      num_prompt_tokens,
-      max_new_tokens - 1,
-      temperature_ == -1.0f ? config.temperature : temperature_,
-      wrapped_callback));
-
-  stats_->inference_end_ms = llm::time_in_ms();
-  if (!config.warming) {
-    printf("\n");
-  }
-  RUNNER_ET_LOG(
-      config.warming,
-      "RSS after finishing text generation: %f MiB (0 if unsupported)",
-      llm::get_rss_bytes() / 1024.0 / 1024.0);
-
-  if (num_generated_tokens == max_new_tokens) {
-    RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
-  }
-
-  stats_->num_prompt_tokens = num_prompt_tokens;
-  stats_->num_generated_tokens = num_generated_tokens;
-
-  if (config.warming) {
-    ET_LOG(Info, "Warmup run finished!");
-  } else {
-    // Do not print report during warmup
-    ::executorch::llm::print_report(*stats_);
-  }
-  if (stats_callback) {
-    stats_callback(*stats_);
-  }
-
-  return Error::Ok;
-}
-
-Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
-  // Create a GenerationConfig for warmup
-  llm::GenerationConfig config{
-      .echo = false, .max_new_tokens = max_new_tokens, .warming = true};
-
-  // Call generate with the warmup config
-  Error err = generate(prompt, config);
-
-  // Reset stats after warmup, not resetting the std::unique_ptr!
-  stats_->reset();
-  return err;
+  return llm::create_text_llm_runner(
+      model_path, std::move(tokenizer), data_path);
 }
 
-
-void Runner::stop() {
-  if (is_loaded()) {
-    text_token_generator_->stop();
-  } else {
-    ET_LOG(Error, "Token generator is not loaded, cannot stop");
-  }
-}
 } // namespace example
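
For context, a minimal usage sketch of the new create_llama_runner() entry point (not part of this diff). The artifact paths are placeholders, any parameters after data_path are assumed to keep their defaults, and the generate(prompt, config) call assumes the returned llm::TextLLMRunner exposes an interface analogous to the removed Runner::generate() above.

// Hypothetical caller of the refactored factory; illustrative only.
#include <executorch/examples/models/llama/runner/runner.h>

#include <cstdio>
#include <optional>
#include <string>

int main() {
  // Placeholder artifacts: an exported Llama .pte and its tokenizer file.
  const std::string model_path = "llama.pte";
  const std::string tokenizer_path = "tokenizer.model";

  // The factory now loads the tokenizer itself and returns nullptr on
  // failure, so there is no separate tokenizer setup on the caller side.
  auto runner = example::create_llama_runner(
      model_path, tokenizer_path, /*data_path=*/std::nullopt);
  if (runner == nullptr) {
    fprintf(stderr, "Failed to create llama runner\n");
    return 1;
  }

  // Same GenerationConfig fields the removed Runner::warmup() used above.
  ::executorch::extension::llm::GenerationConfig config{
      .echo = true, .max_new_tokens = 64};

  // Assumed to mirror the removed Runner::generate(prompt, config) shape;
  // token printing and stats reporting are handled inside the runner.
  auto err = runner->generate("Tell me a short story.", config);
  return err == ::executorch::runtime::Error::Ok ? 0 : 1;
}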