@@ -41,13 +41,11 @@ static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 Runner::Runner(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    const float temperature,
     std::optional<const std::string> data_path)
     // NOTE: we observed ~2x loading performance increase on iPhone 15
     // and a ~5% improvement on Galaxy S22 by switching to
     // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : temperature_(temperature),
-      tokenizer_path_(tokenizer_path),
+    : tokenizer_path_(tokenizer_path),
       metadata_({
           {kEnableDynamicShape, false},
           {kMaxSeqLen, 128},
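The NOTE kept in this hunk records a measured design choice: loading the program with FileDataLoader rather than mmap with UseMlockIgnoreErrors. A minimal sketch of how that choice is usually expressed at the Module construction site, assuming the Module::LoadMode enum from executorch/extension/module; the runner's actual member initializer sits outside this hunk and may differ:

#include <memory>
#include <string>

#include <executorch/extension/module/module.h>

using ::executorch::extension::Module;

const std::string model_path = "llama2.pte";  // placeholder path

// Read the whole .pte into memory up front; per the NOTE above, this loaded
// ~2x faster on iPhone 15 and ~5% faster on Galaxy S22 than the mmap path.
auto module = std::make_unique<Module>(model_path, Module::LoadMode::File);

// Prior behavior, for comparison: mmap the file, mlock it, and ignore mlock
// failures.
// auto module = std::make_unique<Module>(
//     model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);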
@@ -68,6 +66,17 @@ Runner::Runner(
       tokenizer_path.c_str());
 }
 
+[[deprecated(
+    "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
+Runner::Runner(
+    const std::string& model_path,
+    const std::string& tokenizer_path,
+    const float temperature,
+    std::optional<const std::string> data_path)
+    : Runner(model_path, tokenizer_path, data_path) {
+  temperature_ = temperature;
+}
+
 bool Runner::is_loaded() const {
   return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
       text_prefiller_ && text_token_generator_;
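For existing call sites, the migration implied by this hunk is to stop pinning temperature at construction and supply it per request instead (the precedence rule appears in the generate() hunk further down). A hypothetical before/after, with placeholder file paths and data_path passed explicitly because its default is not visible in this diff:

// Deprecated: temperature is fixed at construction, and as long as it is set
// it overrides the per-request GenerationConfig temperature.
Runner legacy_runner(
    "llama2.pte", "tokenizer.bin", /*temperature=*/0.8f, std::nullopt);

// Preferred: no constructor temperature; pass one via GenerationConfig when
// calling generate().
Runner runner("llama2.pte", "tokenizer.bin", std::nullopt);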
@@ -133,11 +142,9 @@ Error Runner::load() {
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
+  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
   text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kVocabSize),
-      temperature_);
+      module_.get(), metadata_.at(kUseKVCache));
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
@@ -164,11 +171,9 @@ Error Runner::load() {
 
 Error Runner::generate(
     const std::string& prompt,
-    int32_t seq_len,
+    const ::executorch::extension::llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const llm::Stats&)> stats_callback,
-    bool echo,
-    bool warmup) {
+    std::function<void(const llm::Stats&)> stats_callback) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -178,19 +183,19 @@ Error Runner::generate(
     stats_.model_load_end_ms = llm::time_in_ms();
   }
 
-  if (warmup) {
+  if (config.warming) {
     ET_LOG(Info, "Doing a warmup run...");
   }
 
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after loading model: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // Wrap the token_callback with print function
   std::function<void(const std::string&)> wrapped_callback =
-      [token_callback, warmup](const std::string& piece) {
-        if (!warmup) {
+      [token_callback, config](const std::string& piece) {
+        if (!config.warming) {
           llm::safe_printf(piece.c_str());
           fflush(stdout);
         }
@@ -204,11 +209,6 @@ Error Runner::generate(
   stats_.inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;
 
-  // Set the sequence length to the max seq length if not provided
-  seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxContextLen))
-      ? seq_len
-      : metadata_.at(kMaxContextLen);
-
   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
       prompt,
       /* bos */ 0,
@@ -225,21 +225,22 @@ Error Runner::generate(
   ET_CHECK_MSG(
       num_prompt_tokens < metadata_.at(kMaxContextLen),
       "num_prompt_tokens %d >= max_seq_len_ %" PRId64
-      ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
+      ", Max seq length exceeded - please increase max seq len value in your export script",
       num_prompt_tokens,
       metadata_.at(kMaxContextLen));
-  ET_CHECK_MSG(
-      num_prompt_tokens < seq_len,
-      "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
-      num_prompt_tokens,
-      seq_len);
+
+  // Determine max_new_tokens using the GenerationConfig's resolve method
+  int max_new_tokens = config.resolve_max_new_tokens(
+      metadata_.at(kMaxContextLen), num_prompt_tokens);
+
+  ET_LOG(Info, "Max new tokens resolved: %d", max_new_tokens);
 
   // Prefill first
   // Here feed all tokens to the model and get the next predicted token
   // after the prompt. After that we will enter generate loop.
 
   // print prompts
-  if (echo) {
+  if (config.echo) {
     wrapped_callback(prompt);
   }
   int64_t pos = 0;
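The seq_len clamping that generate() used to do inline now lives behind config.resolve_max_new_tokens(). The real implementation belongs to llm::GenerationConfig and is not shown in this diff; the following is only a sketch of the budget arithmetic it needs to perform, assuming the config carries seq_len and max_new_tokens fields that are negative when unset:

#include <algorithm>

// Sketch only; see llm::GenerationConfig::resolve_max_new_tokens for the
// actual logic, which may differ.
int resolve_max_new_tokens_sketch(
    int max_context_len,
    int num_prompt_tokens,
    int seq_len,          // assumed field, < 0 when unset
    int max_new_tokens) { // assumed field, < 0 when unset
  // Never generate past what the KV cache / context window can hold.
  int budget = max_context_len - num_prompt_tokens;
  if (seq_len > 0) {
    // A total-sequence-length cap also bounds the number of new tokens.
    budget = std::min(budget, seq_len - num_prompt_tokens);
  }
  if (max_new_tokens > 0) {
    budget = std::min(budget, max_new_tokens);
  }
  return std::max(budget, 0);
}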
@@ -253,32 +254,38 @@ Error Runner::generate(
   wrapped_callback(
       ET_UNWRAP_TOKENIZER(tokenizer_->decode(cur_token, cur_token)));
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after prompt prefill: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // start the main loop
   prompt_tokens.push_back(cur_token);
+
+  // Generate max_new_tokens - 1 because prefill already generated 1 token.
   int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));
+      prompt_tokens,
+      num_prompt_tokens,
+      max_new_tokens - 1,
+      temperature_ == -1.0f ? config.temperature : temperature_,
+      wrapped_callback));
 
   stats_.inference_end_ms = llm::time_in_ms();
-  if (!warmup) {
+  if (!config.warming) {
     printf("\n");
   }
   RUNNER_ET_LOG(
-      warmup,
+      config.warming,
       "RSS after finishing text generation: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
-  if (num_prompt_tokens + num_generated_tokens == seq_len) {
-    RUNNER_ET_LOG(warmup, "Sequence length (%i tokens) reached!", seq_len);
+  if (num_generated_tokens == max_new_tokens) {
+    RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }
 
   stats_.num_prompt_tokens = num_prompt_tokens;
   stats_.num_generated_tokens = num_generated_tokens;
 
-  if (warmup) {
+  if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
@@ -291,14 +298,15 @@ Error Runner::generate(
   return Error::Ok;
 }
 
-Error Runner::warmup(const std::string& prompt, int32_t seq_len) {
-  Error err = generate(
-      prompt,
-      seq_len,
-      /*token_callback=*/nullptr,
-      /*stats_callbak=*/nullptr,
-      /*echo=*/false,
-      /*warmup=*/true);
+Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
+  // Create a GenerationConfig for warmup
+  llm::GenerationConfig config{
+      .echo = false, .max_new_tokens = max_new_tokens, .warming = true};
+
+  // Call generate with the warmup config
+  Error err = generate(prompt, config);
+
+  // Reset stats after warmup
   stats_.reset();
   return err;
 }
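Taken together, the caller-facing surface after this change looks roughly like the following. A hypothetical fragment (the runner, prompt, and includes are assumed to already exist; only the GenerationConfig fields actually exercised in this diff, namely echo, max_new_tokens, temperature, and warming, are assumed):

::executorch::extension::llm::GenerationConfig config;
config.echo = true;           // echo the prompt back through the token callback
config.max_new_tokens = 256;  // resolved against the context budget inside generate()
config.temperature = 0.8f;    // used unless the deprecated constructor pinned a value

Error err = runner.generate(
    prompt,
    config,
    /*token_callback=*/
    [](const std::string& piece) { llm::safe_printf(piece.c_str()); },
    /*stats_callback=*/
    [](const llm::Stats& stats) { (void)stats; /* e.g. inspect timings */ });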