Commit 660ef77

digantdesai authored and facebook-github-bot committed
Add warmup for Llama (#5756)
Summary:
Load the model, then run everything twice, resetting the stats between the two runs. Also decrease the logging level during the warmup run.

Notes: Tested on Android and Mac, with Llama2 and Llama3; with temperature=0 it produces the same output. This warmup option is disabled by default. It is inspired by the llama.cpp options [[1](https://github.com/ggerganov/llama.cpp/blob/ea9c32be71b91b42ecc538bd902e93cbb5fb36cb/common/common.cpp#L897-L929), [2](https://github.com/ggerganov/llama.cpp/blob/ea9c32be71b91b42ecc538bd902e93cbb5fb36cb/examples/llama-bench/llama-bench.cpp#L1595-L1602)]. Sample [runs](https://www.internalfb.com/phabricator/paste/view/P1613261035).

Pull Request resolved: pytorch/executorch#5756

Reviewed By: mcr229, metascroy

Differential Revision: D63642723

Pulled By: digantdesai

fbshipit-source-id: 39ff257eda182fff423f90582a9f32387cfdb253
1 parent 26dc9fd commit 660ef77
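
How the pieces below fit together: the sketch condenses the flow that main.cpp gains from this commit, using the names from the diffs that follow. The include path, the placeholder model/tokenizer paths, the prompt, and the literal values are illustrative assumptions; the real binary reads all of them from gflags, and its threadpool setup is omitted here.

```cpp
// Minimal sketch (not the full example binary) of the warmup-then-generate flow.
#include <cstdint>
#include <string>

#include <gflags/gflags.h>

// Assumed include path for the example runner shown in the diffs below.
#include <executorch/examples/models/llama2/runner/runner.h>

DEFINE_bool(warmup, false, "Whether to run a warmup run.");

int32_t main(int32_t argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  // Placeholder values; the real main.cpp takes these from flags as well.
  const std::string model_path = "llama2.pte";
  const std::string tokenizer_path = "tokenizer.bin";
  const std::string prompt = "Once";
  const float temperature = 0.0f;
  const int32_t seq_len = 10;

  example::Runner runner(model_path, tokenizer_path, temperature);

  // Optional warmup: generate once with demoted logging and suppressed token
  // printing, then reset the collected stats (see Runner::warmup below).
  if (FLAGS_warmup) {
    runner.warmup(prompt, seq_len);
  }

  // The measured run; only its stats end up in print_report().
  runner.generate(prompt, seq_len);
  return 0;
}
```

The extra pass exists so that one-time costs (model load side effects, allocator and threadpool spin-up) fall into the discarded warmup stats rather than into the numbers reported for the measured run.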

6 files changed: +76 / -16 lines changed

.ci/scripts/test_llama.sh

Lines changed: 1 addition & 1 deletion
@@ -213,7 +213,7 @@ echo "Creating tokenizer.bin"
 $PYTHON_EXECUTABLE -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
 
 
-RUNTIME_ARGS="--model_path=${EXPORTED_MODEL_NAME} --tokenizer_path=tokenizer.bin --prompt=Once --temperature=0 --seq_len=10"
+RUNTIME_ARGS="--model_path=${EXPORTED_MODEL_NAME} --tokenizer_path=tokenizer.bin --prompt=Once --temperature=0 --seq_len=10 --warmup=1"
 # Check build tool.
 echo "Running ${EXPORTED_MODEL_NAME} in portable mode"
 if [[ "${BUILD_TOOL}" == "buck2" ]]; then

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 .hypothesis
 buck-out/
-cmake-out/
+cmake-out*
+.DS_Store
 cmake-android-out/
 cmake-out-android/
 cmake-ios-out/

examples/models/llama2/main.cpp

Lines changed: 7 additions & 0 deletions
@@ -39,6 +39,8 @@ DEFINE_int32(
     -1,
     "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");
 
+DEFINE_bool(warmup, false, "Whether to run a warmup run.");
+
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
@@ -57,6 +59,8 @@ int32_t main(int32_t argc, char** argv) {
 
   int32_t cpu_threads = FLAGS_cpu_threads;
 
+  bool warmup = FLAGS_warmup;
+
 #if defined(ET_USE_THREADPOOL)
   uint32_t num_performant_cores = cpu_threads == -1
       ? torch::executorch::cpuinfo::get_num_performant_cores()
@@ -71,6 +75,9 @@ int32_t main(int32_t argc, char** argv) {
   // create llama runner
   example::Runner runner(model_path, tokenizer_path, temperature);
 
+  if (warmup) {
+    runner.warmup(prompt, seq_len);
+  }
   // generate
   runner.generate(prompt, seq_len);
 
examples/models/llama2/runner/runner.cpp

Lines changed: 48 additions & 13 deletions
@@ -146,12 +146,21 @@ Error Runner::load() {
   return Error::Ok;
 }
 
+// Don't print with the same priority during warmup
+#define RUNNER_ET_LOG(warmup, format, ...) \
+  if (warmup) {                            \
+    ET_LOG(Debug, format, __VA_ARGS__);    \
+  } else {                                 \
+    ET_LOG(Info, format, __VA_ARGS__);     \
+  }
+
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const llm::Stats&)> stats_callback,
-    bool echo) {
+    bool echo,
+    bool warmup) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -161,16 +170,22 @@ Error Runner::generate(
     stats_.model_load_end_ms = llm::time_in_ms();
   }
 
-  ET_LOG(
-      Info,
+  if (warmup) {
+    ET_LOG(Info, "Doing a warmup run...");
+  }
+
+  RUNNER_ET_LOG(
+      warmup,
       "RSS after loading model: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // Wrap the token_callback with print function
   std::function<void(const std::string&)> wrapped_callback =
-      [token_callback](const std::string& piece) {
-        llm::safe_printf(piece.c_str());
-        fflush(stdout);
+      [token_callback, warmup](const std::string& piece) {
+        if (!warmup) {
+          llm::safe_printf(piece.c_str());
+          fflush(stdout);
+        }
         if (token_callback) {
           token_callback(piece);
         }
@@ -228,8 +243,8 @@ Error Runner::generate(
 
   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
-  ET_LOG(
-      Info,
+  RUNNER_ET_LOG(
+      warmup,
       "RSS after prompt prefill: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
@@ -239,26 +254,46 @@ Error Runner::generate(
       prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));
 
   stats_.inference_end_ms = llm::time_in_ms();
-  printf("\n");
-  ET_LOG(
-      Info,
+  if (!warmup) {
+    printf("\n");
+  }
+  RUNNER_ET_LOG(
+      warmup,
       "RSS after finishing text generation: %f MiB (0 if unsupported)",
       llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   if (num_prompt_tokens + num_generated_tokens == seq_len) {
-    ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
+    RUNNER_ET_LOG(warmup, "Sequence length (%i tokens) reached!", seq_len);
   }
 
   stats_.num_prompt_tokens = num_prompt_tokens;
   stats_.num_generated_tokens = num_generated_tokens;
-  ::executorch::llm::print_report(stats_);
+
+  if (warmup) {
+    ET_LOG(Info, "Warmup run finished!");
+  } else {
+    // Do not print report during warmup
+    ::executorch::llm::print_report(stats_);
+  }
   if (stats_callback) {
     stats_callback(stats_);
   }
 
   return Error::Ok;
 }
 
+Error Runner::warmup(const std::string& prompt, int32_t seq_len) {
+  Error err = generate(
+      prompt,
+      seq_len,
+      /*token_callback=*/nullptr,
+      /*stats_callbak=*/nullptr,
+      /*echo=*/false,
+      /*warmup=*/true);
+  stats_.reset();
+  return err;
+}
+
 void Runner::stop() {
   if (is_loaded()) {
     text_token_generator_->stop();

examples/models/llama2/runner/runner.h

Lines changed: 5 additions & 1 deletion
@@ -41,7 +41,11 @@ class Runner {
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const ::executorch::extension::llm::Stats&)>
           stats_callback = {},
-      bool echo = true);
+      bool echo = true,
+      bool warming = false);
+  ::executorch::runtime::Error warmup(
+      const std::string& prompt,
+      int32_t seq_len = 128);
   void stop();
 
  private:

extension/llm/runner/stats.h

Lines changed: 13 additions & 0 deletions
@@ -52,6 +52,19 @@ struct Stats {
     aggregate_sampling_timer_start_timestamp = 0;
   }
 
+  void reset() {
+    model_load_start_ms = 0;
+    model_load_end_ms = 0;
+    inference_start_ms = 0;
+    prompt_eval_end_ms = 0;
+    first_token_ms = 0;
+    inference_end_ms = 0;
+    aggregate_sampling_time_ms = 0;
+    num_prompt_tokens = 0;
+    num_generated_tokens = 0;
+    aggregate_sampling_timer_start_timestamp = 0;
+  }
+
  private:
   long aggregate_sampling_timer_start_timestamp = 0;
 };
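
As a side note on the reset() added above: the toy program below is a simplified stand-in (not the ExecuTorch Stats type) that illustrates the pattern described in the commit summary, namely that the warmup pass fills the counters, reset() clears them, and only the second pass contributes to what gets reported.

```cpp
#include <chrono>
#include <cstdio>

// Simplified stand-in for executorch::extension::llm::Stats (illustration only).
struct ToyStats {
  long inference_start_ms = 0;
  long inference_end_ms = 0;
  long num_generated_tokens = 0;

  void reset() {
    inference_start_ms = 0;
    inference_end_ms = 0;
    num_generated_tokens = 0;
  }
};

static long now_ms() {
  using namespace std::chrono;
  return static_cast<long>(
      duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count());
}

// Stand-in for a generation pass that populates the stats, as Runner::generate() does.
static void fake_generate(ToyStats& stats, long tokens) {
  stats.inference_start_ms = now_ms();
  stats.num_generated_tokens += tokens;
  stats.inference_end_ms = now_ms();
}

int main() {
  ToyStats stats;

  fake_generate(stats, 10);  // warmup pass; its numbers are intentionally discarded
  stats.reset();             // what Runner::warmup() does after its generate() call

  fake_generate(stats, 10);  // measured pass
  std::printf("generated %ld tokens in %ld ms\n",
              stats.num_generated_tokens,
              stats.inference_end_ms - stats.inference_start_ms);
  return 0;
}
```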
