@@ -54,10 +54,10 @@ Runner::Runner(
           {kUseSDPAWithKVCache, false},
       }) {
   if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
+    module_ = std::make_shared<Module>(
         model_path, data_path.value(), Module::LoadMode::File);
   } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
+    module_ = std::make_shared<Module>(model_path, Module::LoadMode::File);
   }
   ET_LOG(
       Info,
@@ -89,7 +89,7 @@ Error Runner::load() {
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
   // load tokenizer. Assuming tiktoken is the default tokenizer
   tokenizer_ = nullptr;
-  tokenizer_ = get_tiktoken_for_llama();
+  tokenizer_ = get_tiktoken_for_llama<decltype(tokenizer_)>();
   ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
   // Rely on tiktoken to throw error if the artifact is incompatible. Then we
   // fallback to BPE tokenizer.
@@ -99,7 +99,7 @@ Error Runner::load() {
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
         tokenizer_path_.c_str());
     tokenizer_.reset();
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    tokenizer_ = std::make_shared<::tokenizers::Llama2cTokenizer>();
     err = tokenizer_->load(tokenizer_path_);
     ET_CHECK_TK_OK_OR_RETURN_ERROR(
         err,
@@ -143,20 +143,21 @@ Error Runner::load() {
     }
   }
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
+  text_decoder_runner_ = std::make_shared<llm::TextDecoderRunner>(
+      module_, metadata_.at(kUseKVCache));
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
+      text_decoder_runner_,
       metadata_.at(kUseKVCache),
       metadata_.at(kEnableDynamicShape),
       metadata_.at(kMaxSeqLen));

+  stats_ = std::make_shared<llm::Stats>();
   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
+      tokenizer_,
+      text_decoder_runner_,
       metadata_.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats_);

   return Error::Ok;
 }
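Note on the hunk above: the decoder runner, prefiller, and token generator now receive `shared_ptr`s instead of raw pointers obtained via `.get()`, so each component shares ownership of its dependencies rather than borrowing them from the runner. A minimal, self-contained sketch of that pattern, with illustrative stand-in class names rather than the real ExecuTorch types:

#include <iostream>
#include <memory>

// Illustrative stand-in for a decoder component; not the ExecuTorch API.
struct Decoder {
  void step() { std::cout << "decode step\n"; }
};

struct Generator {
  // Shares ownership: the Decoder stays alive as long as any Generator holds
  // it, unlike a raw pointer from decoder.get(), which dangles if the owning
  // unique_ptr is destroyed or reassigned first.
  explicit Generator(std::shared_ptr<Decoder> decoder)
      : decoder_(std::move(decoder)) {}

  void generate() { decoder_->step(); }

 private:
  std::shared_ptr<Decoder> decoder_;
};

int main() {
  auto decoder = std::make_shared<Decoder>();
  Generator generator(decoder);

  decoder.reset();       // drop the creator's reference...
  generator.generate();  // ...the Generator's shared_ptr keeps the Decoder alive
}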
@@ -178,9 +179,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }

   if (config.warming) {
@@ -206,7 +207,7 @@ Error Runner::generate(
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.

-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;

   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -247,8 +248,8 @@ Error Runner::generate(
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();

   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -269,7 +270,7 @@ Error Runner::generate(
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback));

-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
@@ -282,17 +283,17 @@ Error Runner::generate(
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }

-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;

   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }

   return Error::Ok;
@@ -307,7 +308,7 @@ Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   Error err = generate(prompt, config);

   // Reset stats after warmup
-  stats_.reset();
+  stats_->reset();
   return err;
 }
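The `Stats` object follows the same pattern: it is now heap-allocated as a `shared_ptr` held by both the runner and the token generator, member access switches from `.` to `->`, and the by-reference call sites (`print_report`, `stats_callback`) simply dereference it. A hedged sketch of that shape, again with illustrative names rather than the real ExecuTorch signatures:

#include <functional>
#include <iostream>
#include <memory>

// Illustrative stand-in for a stats struct; not the ExecuTorch llm::Stats.
struct Stats {
  long num_prompt_tokens = 0;
  long num_generated_tokens = 0;
  void reset() { num_prompt_tokens = num_generated_tokens = 0; }
};

// Reporting APIs keep taking a const reference, so callers dereference the
// shared_ptr instead of passing the member object directly.
void print_report(const Stats& stats) {
  std::cout << "prompt=" << stats.num_prompt_tokens
            << " generated=" << stats.num_generated_tokens << '\n';
}

int main() {
  auto stats = std::make_shared<Stats>();  // shared between runner and generator
  std::function<void(const Stats&)> stats_callback = print_report;

  stats->num_prompt_tokens = 12;           // member access via -> instead of .
  stats->num_generated_tokens = 48;

  print_report(*stats);                    // dereference where a reference is expected
  if (stats_callback) {
    stats_callback(*stats);
  }
  stats->reset();                          // e.g. reset after a warmup run
}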