@@ -52,12 +52,13 @@ Runner::Runner(
           {kMaxContextLen, 128},
           {kUseKVCache, true},
           {kUseSDPAWithKVCache, false},
-      }) {
+      }),
+      stats_(std::make_shared<llm::Stats>()) {
   if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
+    module_ = std::make_shared<Module>(
         model_path, data_path.value(), Module::LoadMode::File);
   } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
+    module_ = std::make_shared<Module>(model_path, Module::LoadMode::File);
   }
   ET_LOG(
       Info,
@@ -89,7 +90,7 @@ Error Runner::load() {
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
   // load tokenizer. Assuming tiktoken is the default tokenizer
   tokenizer_ = nullptr;
-  tokenizer_ = get_tiktoken_for_llama();
+  tokenizer_ = get_tiktoken_for_llama<decltype(tokenizer_)>();
   ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
   // Rely on tiktoken to throw error if the artifact is incompatible. Then we
   // fallback to BPE tokenizer.
@@ -99,7 +100,7 @@ Error Runner::load() {
99100 " Failed to load %s as a Tiktoken artifact, trying BPE tokenizer" ,
100101 tokenizer_path_.c_str ());
101102 tokenizer_.reset ();
102- tokenizer_ = std::make_unique <::tokenizers::Llama2cTokenizer>();
103+ tokenizer_ = std::make_shared <::tokenizers::Llama2cTokenizer>();
103104 err = tokenizer_->load (tokenizer_path_);
104105 ET_CHECK_TK_OK_OR_RETURN_ERROR (
105106 err,
@@ -143,20 +144,20 @@ Error Runner::load() {
     }
   }
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
+  text_decoder_runner_ = std::make_shared<llm::TextDecoderRunner>(
+      module_, metadata_.at(kUseKVCache));
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
+      text_decoder_runner_,
       metadata_.at(kUseKVCache),
       metadata_.at(kEnableDynamicShape),
       metadata_.at(kMaxSeqLen));

   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
+      tokenizer_,
+      text_decoder_runner_,
       metadata_.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats_);

   return Error::Ok;
 }
@@ -178,9 +179,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }

   if (config.warming) {
@@ -206,7 +207,7 @@ Error Runner::generate(
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.

-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;

   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -247,8 +248,8 @@ Error Runner::generate(
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();

   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -269,7 +270,7 @@ Error Runner::generate(
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback));

-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
@@ -282,17 +283,17 @@ Error Runner::generate(
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }

-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;

   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }

   return Error::Ok;
@@ -307,7 +308,7 @@ Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   Error err = generate(prompt, config);

   // Reset stats after warmup
-  stats_.reset();
+  stats_->reset();
   return err;
 }

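The net effect of this diff is that `stats_`, `module_`, `tokenizer_`, and `text_decoder_runner_` are now held as `std::shared_ptr` and handed to the prefiller and token generator by shared ownership rather than via raw `.get()` pointers or `&stats_`, so the generator writes into the same `Stats` object the runner later prints and resets. Below is a minimal standalone sketch of that ownership pattern; `Stats`, `TokenGenerator`, and `Runner` here are simplified illustrative stand-ins, not the ExecuTorch `llm::Stats` or `llm::TextTokenGenerator` classes.

// Sketch: a runner and a token generator sharing one Stats object via
// std::shared_ptr, so neither side has to guarantee the other's lifetime.
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <utility>

struct Stats {
  int64_t inference_start_ms = 0;
  int64_t inference_end_ms = 0;
  int64_t num_generated_tokens = 0;
  void reset() { *this = Stats{}; }  // drop all recorded numbers
};

static int64_t time_in_ms() {
  using namespace std::chrono;
  return duration_cast<milliseconds>(
             steady_clock::now().time_since_epoch())
      .count();
}

class TokenGenerator {
 public:
  explicit TokenGenerator(std::shared_ptr<Stats> stats)
      : stats_(std::move(stats)) {}
  void generate(int64_t n) {
    // Writes land in the shared Stats object, not through a raw pointer
    // whose lifetime the caller must manage separately.
    stats_->num_generated_tokens += n;
  }

 private:
  std::shared_ptr<Stats> stats_;
};

class Runner {
 public:
  Runner() : stats_(std::make_shared<Stats>()), generator_(stats_) {}
  void run() {
    stats_->inference_start_ms = time_in_ms();
    generator_.generate(8);
    stats_->inference_end_ms = time_in_ms();
    std::printf("generated %lld tokens\n",
                static_cast<long long>(stats_->num_generated_tokens));
  }
  void warmup() {
    run();
    stats_->reset();  // discard warmup numbers, as the diff does after warmup
  }

 private:
  std::shared_ptr<Stats> stats_;  // declared before generator_ so it is built first
  TokenGenerator generator_;
};

int main() {
  Runner r;
  r.warmup();
  r.run();
  return 0;
}

The trade-off is a small reference-count overhead in exchange for not having to reason about whether the runner outlives every component that previously kept a raw pointer into it.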