 // A llama 3.2 runner that includes preprocessing and post processing
 // logic. The module takes in a string as input and emits a string as output.
 
+#include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/client_mem.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h>
@@ -58,7 +59,7 @@ void print_performance_report(
     outfile << num_tok;
     outfile.close();
   } else {
-    ET_CHECK_MSG(false, "Error saving the inference speed file");
+    ET_LOG(Error, "Error saving the inference speed file");
   }
 }
 
@@ -83,13 +84,6 @@ void save_logits(
 
 } // namespace
 
-std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path,
-    Version version) {
-  auto special_tokens = get_special_tokens(version);
-  return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
-}
-
 template <typename T>
 Runner<T>::Runner(
     std::unique_ptr<executorch::extension::Module> module,
@@ -179,7 +173,8 @@ Error Runner<T>::load() {
     eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
     eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
   } else {
-    tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default);
+    tokenizer_ =
+        example::load_llama_tokenizer(tokenizer_path_, Version::Default);
     if (tokenizer_ == nullptr) {
       ET_LOG(
           Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str());
@@ -321,13 +316,30 @@ Error Runner<T>::load() {
 
 template <typename T>
 Error Runner<T>::generate(
+    const std::string& prompt,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_pos(
+    const std::string& prompt,
+    int64_t start_pos,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  // TODO: currently only support start_pos == 0
+  return generate_from_prompt_or_file(
+      prompt, false, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_prompt_or_file(
     const std::string& prompt,
     bool tokenized_prompt,
-    int32_t seq_len,
+    const llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
@@ -336,6 +348,7 @@ Error Runner<T>::generate(
   }
   stats_.inference_start_ms = time_in_ms();
 
+  int32_t seq_len = config.seq_len;
   seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
   int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;
341354
@@ -374,7 +387,7 @@ Error Runner<T>::generate(
       "sequence length exceeded - please increase the seq_len value");
 
   // Prompt Processor first
-  if (token_callback) {
+  if (token_callback && config.echo) {
     token_callback(prompt);
   }
   bool dump_logits = dump_logits_path_.empty() ? false : true;
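
For reference, the new call pattern bundles the old seq_len and echo arguments into a GenerationConfig. The sketch below is illustrative only and is not code from this PR: it is written as if inside runner.cpp, where llm already refers to the ExecuTorch llm namespace and Runner is the class modified above, it assumes a runner instance that has been constructed and loaded, and it touches only the config fields this diff actually reads (seq_len and echo).

#include <iostream>
#include <string>

// Usage sketch (hypothetical helper, not part of the PR). RunnerT stands in
// for an instantiation of the Runner class defined in runner.cpp.
template <typename RunnerT>
void generate_example(RunnerT& runner) {
  // GenerationConfig replaces the old seq_len/echo/warming arguments.
  llm::GenerationConfig config;
  config.seq_len = 128;  // requested sequence length, clamped to context_len_ above
  config.echo = true;    // echo the prompt through token_callback, as gated above

  auto err = runner.generate(
      "What is the capital of France?",
      config,
      // token_callback: stream each decoded piece as it is produced.
      [](const std::string& piece) { std::cout << piece; },
      // stats_callback: receives the runner's inference timing stats.
      [](const auto& stats) { (void)stats; });
  (void)err;  // a real caller would check the returned Error
}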