 // A llama 3.2 runner that includes preprocessing and post processing
 // logic. The module takes in a string as input and emits a string as output.

+#include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/client_mem.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h>
@@ -58,7 +59,7 @@ void print_performance_report(
     outfile << num_tok;
     outfile.close();
   } else {
-    ET_CHECK_MSG(false, "Error saving the inference speed file");
+    ET_LOG(Error, "Error saving the inference speed file");
   }
 }

@@ -83,13 +84,6 @@ void save_logits(

 } // namespace

-std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path,
-    Version version) {
-  auto special_tokens = get_special_tokens(version);
-  return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
-}
-
 template <typename T>
 Runner<T>::Runner(
     std::unique_ptr<executorch::extension::Module> module,
@@ -181,7 +175,8 @@ Error Runner<T>::load() {
     eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
     eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
   } else {
-    tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default);
+    tokenizer_ =
+        example::load_llama_tokenizer(tokenizer_path_, Version::Default);
     if (tokenizer_ == nullptr) {
       ET_LOG(
           Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str());
@@ -323,13 +318,32 @@ Error Runner<T>::load() {

 template <typename T>
 Error Runner<T>::generate(
+    const std::string& prompt,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}
+
+template <typename T>
+Error Runner<T>::generate_from_pos(
+    const std::string& prompt,
+    int64_t start_pos,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  // TODO: currently only support start_pos == 0
+  return generate_from_prompt_or_file(
+      prompt, false, config, token_callback, stats_callback);
+}
+
+template <typename T>
+Error Runner<T>::generate_from_prompt_or_file(
     const std::string& prompt,
     bool tokenized_prompt,
-    int32_t seq_len,
+    const llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
@@ -338,6 +352,7 @@ Error Runner<T>::generate(
   }
   stats_.inference_start_ms = time_in_ms();

+  int32_t seq_len = config.seq_len;
   seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
   int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;

@@ -376,7 +391,7 @@ Error Runner<T>::generate(
       "sequence length exceeded - please increase the seq_len value");

   // Prompt Processor first
-  if (token_callback) {
+  if (token_callback && config.echo) {
     token_callback(prompt);
   }
   bool dump_logits = dump_logits_path_.empty() ? false : true;
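
For context, a minimal caller-side sketch of the reworked, GenerationConfig-driven entry point. This is an illustration only: runner construction and loading are omitted, the `llm` namespace alias and the callback parameter types are assumed from the surrounding sources, and only the `seq_len` and `echo` fields are confirmed by this diff.

#include <iostream>
#include <string>

// Sketch: assumes `runner` is a constructed and loaded example::Runner<T>,
// and that `llm` aliases executorch::extension::llm as in this file.
llm::GenerationConfig config;
config.seq_len = 512;  // clamped to context_len_ inside generate()
config.echo = false;   // prompt is no longer echoed through token_callback

const auto err = runner.generate(
    "Tell me a short story.",
    config,
    [](const std::string& piece) { std::cout << piece; },  // token_callback
    [](const Stats& stats) { (void)stats; });              // stats_callback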