@@ -146,12 +146,21 @@ Error Runner::load() {
   return Error::Ok;
 }
 
+// Don't print with the same priority during warmup
+#define RUNNER_ET_LOG(warmup, format, ...) \
+  if (warmup) {                            \
+    ET_LOG(Debug, format, __VA_ARGS__);    \
+  } else {                                 \
+    ET_LOG(Info, format, __VA_ARGS__);     \
+  }
+
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const llm::Stats&)> stats_callback,
-    bool echo) {
+    bool echo,
+    bool warmup) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
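Aside: because RUNNER_ET_LOG expands to a bare if/else, a trailing semicolon at an unbraced call site (e.g. if (cond) RUNNER_ET_LOG(...); else ...) would turn the semicolon into an empty statement and leave the caller's else without a matching if, failing to compile. A minimal sketch of a more defensive variant, assuming only the ET_LOG macro used above; this is an alternative, not what the patch uses:

    // Hypothetical alternative sketch; do/while makes the macro behave as
    // a single statement and tolerate a trailing ';' in if/else chains.
    #define RUNNER_ET_LOG_SAFE(warmup, format, ...) \
      do {                                          \
        if (warmup) {                               \
          ET_LOG(Debug, format, __VA_ARGS__);       \
        } else {                                    \
          ET_LOG(Info, format, __VA_ARGS__);        \
        }                                           \
      } while (0)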
@@ -161,16 +170,22 @@ Error Runner::generate(
     stats_.model_load_end_ms = llm::time_in_ms();
   }
 
-  ET_LOG(
-      Info,
+  if (warmup) {
+    ET_LOG(Info, "Doing a warmup run...");
+  }
+
+  RUNNER_ET_LOG(
+      warmup,
       "RSS after loading model: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // Wrap the token_callback with print function
   std::function<void(const std::string&)> wrapped_callback =
-      [token_callback](const std::string& piece) {
-        llm::safe_printf(piece.c_str());
-        fflush(stdout);
+      [token_callback, warmup](const std::string& piece) {
+        if (!warmup) {
+          llm::safe_printf(piece.c_str());
+          fflush(stdout);
+        }
         if (token_callback) {
           token_callback(piece);
         }
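The wrapper above captures warmup by value, so console echo is suppressed during the warmup pass while each token is still forwarded to the caller's callback. A self-contained sketch of the same pattern (names are illustrative, not from the runner):

    #include <cstdio>
    #include <functional>
    #include <string>

    int main() {
      bool warmup = true;
      std::function<void(const std::string&)> user_cb = nullptr;
      auto wrapped = [user_cb, warmup](const std::string& piece) {
        if (!warmup) {  // skip console output during warmup
          std::fputs(piece.c_str(), stdout);
          std::fflush(stdout);
        }
        if (user_cb) {  // always forward to the caller's callback
          user_cb(piece);
        }
      };
      wrapped("hello");  // prints nothing while warmup is true
      return 0;
    }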
@@ -228,8 +243,8 @@ Error Runner::generate(
 
   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
-  ET_LOG(
-      Info,
+  RUNNER_ET_LOG(
+      warmup,
       "RSS after prompt prefill: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
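The RSS figures in these log lines divide llm::get_rss_bytes() by 1024.0 twice; the double-precision divisors matter, since integer divisors would truncate to whole MiB and %f expects a double anyway. A tiny standalone check of the conversion (the byte count is made up):

    #include <cstdio>
    #include <cstdint>

    int main() {
      uint64_t rss_bytes = 268435456;            // pretend RSS of 256 MiB
      double mib = rss_bytes / 1024.0 / 1024.0;  // promotes to double
      std::printf("RSS: %f MiB\n", mib);         // prints "RSS: 256.000000 MiB"
      return 0;
    }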
@@ -239,26 +254,46 @@ Error Runner::generate(
       prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));
 
   stats_.inference_end_ms = llm::time_in_ms();
-  printf("\n");
-  ET_LOG(
-      Info,
+  if (!warmup) {
+    printf("\n");
+  }
+  RUNNER_ET_LOG(
+      warmup,
       "RSS after finishing text generation: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   if (num_prompt_tokens + num_generated_tokens == seq_len) {
-    ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
+    RUNNER_ET_LOG(warmup, "Sequence length (%i tokens) reached!", seq_len);
   }
 
   stats_.num_prompt_tokens = num_prompt_tokens;
   stats_.num_generated_tokens = num_generated_tokens;
-  ::executorch::llm::print_report(stats_);
+
+  if (warmup) {
+    ET_LOG(Info, "Warmup run finished!");
+  } else {
+    // Do not print report during warmup
+    ::executorch::llm::print_report(stats_);
+  }
   if (stats_callback) {
     stats_callback(stats_);
   }
 
   return Error::Ok;
 }
 
+Error Runner::warmup(const std::string& prompt, int32_t seq_len) {
+  Error err = generate(
+      prompt,
+      seq_len,
+      /*token_callback=*/nullptr,
+      /*stats_callback=*/nullptr,
+      /*echo=*/false,
+      /*warmup=*/true);
+  stats_.reset();
+  return err;
+}
+
 void Runner::stop() {
   if (is_loaded()) {
     text_token_generator_->stop();
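Taken together, a caller can now do one silent pass before the timed run. A hedged sketch of the call site (runner, prompt, and seq_len are assumed to exist; the diff does not show the caller):

    // Hypothetical usage: warm up caches and allocations silently, then
    // generate for real with console echo and the perf report enabled.
    Error err = runner.warmup(prompt, seq_len);
    if (err == Error::Ok) {
      err = runner.generate(
          prompt,
          seq_len,
          /*token_callback=*/nullptr,
          /*stats_callback=*/nullptr,
          /*echo=*/true,
          /*warmup=*/false);
    }

Note that warmup() resets stats_ on the way out, so the warmup pass leaves no trace in the report printed by the subsequent generate() call.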