 // A llama 3.2 runner that includes preprocessing and post processing
 // logic. The module takes in a string as input and emits a string as output.
 
+#include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/client_mem.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h>
@@ -59,7 +60,7 @@ void print_performance_report(
     outfile << num_tok;
     outfile.close();
   } else {
-    ET_CHECK_MSG(false, "Error saving the inference speed file");
+    ET_LOG(Error, "Error saving the inference speed file");
   }
 }
 
@@ -84,13 +85,6 @@ void save_logits(
 
 } // namespace
 
-std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path,
-    Version version) {
-  auto special_tokens = get_special_tokens(version);
-  return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
-}
-
 Runner::Runner(
     const std::string& decoder_model_version,
     const std::string& model_path,
@@ -175,7 +169,8 @@ Error Runner::load() {
     eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
     eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
   } else {
-    tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default);
+    tokenizer_ =
+        example::load_llama_tokenizer(tokenizer_path_, Version::Default);
     if (tokenizer_ == nullptr) {
       ET_LOG(
           Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str());
@@ -313,13 +308,30 @@ Error Runner::load() {
 }
 
 Error Runner::generate(
+    const std::string& prompt,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_pos(
+    const std::string& prompt,
+    int64_t start_pos,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  // TODO: currently only support start_pos == 0
+  return generate_from_prompt_or_file(
+      prompt, false, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_prompt_or_file(
     const std::string& prompt,
     bool tokenized_prompt,
-    int32_t seq_len,
+    const llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
@@ -328,6 +340,7 @@ Error Runner::generate(
   }
   stats_.inference_start_ms = time_in_ms();
 
+  int32_t seq_len = config.seq_len;
   seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
   int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;
 
@@ -366,7 +379,7 @@ Error Runner::generate(
366379 " sequence length exceeded - please increase the seq_len value" );
367380
368381 // Prompt Processor first
369- if (token_callback) {
382+ if (token_callback && config. echo ) {
370383 token_callback (prompt);
371384 }
372385 bool dump_logits = dump_logits_path_.empty () ? false : true ;
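
Note on the new calling convention: generate() now takes a single llm::GenerationConfig instead of the separate seq_len/echo/warming arguments. A minimal usage sketch, assuming `llm` aliases the ExecuTorch LLM extension namespace and `runner` is an already-constructed Runner; only the config fields seq_len and echo appear in this diff, the prompt and values here are illustrative:

    llm::GenerationConfig config;
    config.seq_len = 128;  // clamped to context_len_ inside generate_from_prompt_or_file()
    config.echo = true;    // if true, the prompt itself is forwarded to token_callback

    // Stream generated text pieces as they arrive, then receive the final Stats.
    runner.generate(
        "What is the capital of France?",
        config,
        [](const std::string& piece) { std::cout << piece; },
        [](const Stats& stats) { /* e.g. inspect timing fields */ });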