
Commit 1d5122f

billmguo authored and facebook-github-bot committed
support ET dump for llama3 runner
Summary: Support ET dump for the llama3 runner to make it easier to understand performance regression issues. Differential Revision: D67656207
1 parent c001634 commit 1d5122f
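
For orientation (not part of the commit): the change routes an ETDumpGen event tracer through Module::load_method and then serializes the collected profiling events to an .etdp file after generation. Below is a minimal, self-contained sketch of that flow; it assumes the two-argument load_method overload the diff relies on, and model.pte, forward, and model_etdump.etdp are placeholder names.

#include <cstdio>

#include <executorch/devtools/etdump/etdump_flatcc.h>
#include <executorch/extension/module/module.h>

using executorch::extension::Module;

int main() {
  // Placeholder model path; any ExecuTorch .pte program works here.
  Module module("model.pte");

  // ETDumpGen is an EventTracer, so attaching it at load time makes the
  // runtime record per-operator profiling events for this method.
  torch::executor::ETDumpGen etdump_gen;
  if (module.load_method("forward", &etdump_gen) !=
      executorch::runtime::Error::Ok) {
    return 1;
  }

  // Execute the method (a real runner would pass the model inputs here);
  // events accumulate in the tracer during execution.
  auto outputs = module.execute("forward");
  (void)outputs;

  // Serialize the events and write them to an .etdp file that the
  // ExecuTorch devtools can inspect.
  torch::executor::etdump_result dump = etdump_gen.get_etdump_data();
  if (dump.buf != nullptr && dump.size > 0) {
    FILE* f = fopen("model_etdump.etdp", "w+");
    if (f != nullptr) {
      fwrite(dump.buf, 1, dump.size, f);
      fclose(f);
    }
  }
  return 0;
}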

File tree

5 files changed: +69 -5 lines


examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 # model sharding with custom op
-set(CUSTOM_OP_SRCS_FILE
+set(CUSTOM_OP_SRCS_FILE
     "${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp"
 )
 add_library(custom_ops ${CUSTOM_OP_SRCS_FILE})
@@ -45,7 +45,7 @@ list(
 # build qnn llama3.2 1b runner
 add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs})
 target_include_directories(
-  qnn_llama3_2_runner PUBLIC ${_common_include_directories}
+  qnn_llama3_2_runner PUBLIC ${_common_include_directories} ${EXECUTORCH_SOURCE_DIR}/devtools/etdump
 )

 target_link_libraries(

examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp

Lines changed: 8 additions & 1 deletion
@@ -49,6 +49,12 @@ DEFINE_int32(
     0,
     "0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)");

+DEFINE_bool(
+    gen_etdump,
+    false,
+    "false: Disable ET dump / True: Enable ET dump (default: false)");
+
+
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);

@@ -57,7 +63,8 @@ int main(int argc, char** argv) {
       {FLAGS_model_path},
       FLAGS_tokenizer_path.c_str(),
       FLAGS_temperature,
-      FLAGS_eval_mode);
+      FLAGS_eval_mode,
+      FLAGS_gen_etdump);
   std::vector<char> buf;
   buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
   std::ofstream fout(FLAGS_output_path.c_str());
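
As a usage illustration (not from the commit), the new flag simply becomes the fifth constructor argument of example::Runner. A hedged sketch of calling it directly follows; the include path, model and tokenizer paths, and the temperature value are placeholders.

#include <string>
#include <vector>

#include "examples/qualcomm/oss_scripts/llama3_2/runner/runner.h"  // assumed include path

int main() {
  // Placeholder artifacts; in qnn_llama3_2_runner.cpp these come from gflags.
  std::vector<std::string> models = {"llama3_2_qnn.pte"};
  std::string tokenizer_path = "tokenizer.bin";

  // eval_mode 2 = HybridMode (prefill + kv); gen_etdump = true tells the
  // runner to create ETDumpGen tracers for the loaded methods.
  example::Runner runner(
      models, tokenizer_path, /*temperature=*/0.8f, /*eval_mode=*/2,
      /*gen_etdump=*/true);

  // After runner.generate(...) completes, prefill_etdump.etdp and
  // decode_etdump.etdp are written to the working directory.
  return 0;
}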

examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp

Lines changed: 48 additions & 1 deletion
@@ -41,7 +41,8 @@ Runner::Runner(
     const std::vector<std::string>& models_path,
     const std::string& tokenizer_path,
     const float temperature,
-    const int eval_mode)
+    const int eval_mode,
+    const bool gen_etdump)
     : n_bos_(1),
       n_eos_(1),
       tokenizer_path_(tokenizer_path),
@@ -54,6 +55,26 @@ Runner::Runner(
   }
   ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str());
   ET_LOG(Info, "eval mode=%d", eval_mode);
+  if (gen_etdump) {
+    gen_etdump_ = true;
+    switch (eval_mode) {
+      case 0:
+        prefill_dump_ = new torch::executor::ETDumpGen();
+        break;
+      case 1:
+        decode_dump_ = new torch::executor::ETDumpGen();
+        break;
+      case 2:
+        prefill_dump_ = new torch::executor::ETDumpGen();
+        decode_dump_ = new torch::executor::ETDumpGen();
+        break;
+      default:
+        ET_CHECK_MSG(false, "Unsupported eval mode");
+        break;
+    }
+    prefill_etdump_path_ = "prefill_etdump.etdp";
+    decode_etdump_path_ = "decode_etdump.etdp";
+  }
 }

 bool Runner::is_loaded() const {
@@ -91,9 +112,15 @@ Error Runner::load() {

   for (std::shared_ptr<Module>& module : modules_) {
     if (!prefill_forward_name_.empty()) {
+      if (gen_etdump_) {
+        ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_, prefill_dump_));
+      }
       ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_));
     }
     if (!kv_forward_name_.empty()) {
+      if (gen_etdump_) {
+        ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_, decode_dump_));
+      }
       ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_));
     }
   }
@@ -395,6 +422,8 @@ Error Runner::generate(

   stats_.num_prompt_tokens = num_prompt_tokens;
   stats_.num_generated_tokens = pos - num_prompt_tokens;
+  if (gen_etdump_)
+    gen_etdump_data();
   printReport(stats_);
   if (stats_callback) {
     stats_callback(stats_);
@@ -403,6 +432,24 @@ Error Runner::generate(
   return Error::Ok;
 }

+void Runner::gen_etdump_data() {
+  // Dump the prefill and decode ETDump data to their respective files.
+  if (prefill_dump_ != nullptr) {
+    torch::executor::etdump_result result = prefill_dump_->get_etdump_data();
+    FILE* ptr = fopen(prefill_etdump_path_.c_str(), "w+");
+    fwrite(result.buf, 1, result.size, ptr);
+    fclose(ptr);
+    prefill_dump_->reset();
+  }
+  if (decode_dump_ != nullptr) {
+    torch::executor::etdump_result result = decode_dump_->get_etdump_data();
+    FILE* ptr = fopen(decode_etdump_path_.c_str(), "w+");
+    fwrite(result.buf, 1, result.size, ptr);
+    fclose(ptr);
+    decode_dump_->reset();
+  }
+}
+
 namespace {
 void printReport(const Runner::Stats& stats) {
   printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());
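
The new gen_etdump_data() writes the serialized buffer straight to disk without checking that fopen succeeded or that the tracer produced any data. A slightly more defensive variant is sketched below; it is not part of the commit, and write_etdump is a hypothetical helper name.

#include <cstdio>
#include <string>

#include <executorch/devtools/etdump/etdump_flatcc.h>

// Hypothetical helper: serialize one tracer's ETDump buffer to `path`.
// Returns false when there is nothing to write or the file cannot be opened.
static bool write_etdump(
    torch::executor::ETDumpGen* gen, const std::string& path) {
  if (gen == nullptr) {
    return false;
  }
  torch::executor::etdump_result result = gen->get_etdump_data();
  if (result.buf == nullptr || result.size == 0) {
    return false;
  }
  FILE* f = fopen(path.c_str(), "w+");
  if (f == nullptr) {
    return false;
  }
  const size_t written = fwrite(result.buf, 1, result.size, f);
  fclose(f);
  gen->reset();
  return written == result.size;
}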

examples/qualcomm/oss_scripts/llama3_2/runner/runner.h

Lines changed: 10 additions & 1 deletion
@@ -21,6 +21,8 @@
 #include <executorch/extension/llm/sampler/sampler.h>
 #include <executorch/extension/llm/tokenizer/tokenizer.h>
 #include <executorch/extension/module/module.h>
+#include <executorch/devtools/etdump/etdump_flatcc.h>
+

 namespace example {

@@ -30,7 +32,8 @@ class Runner {
       const std::vector<std::string>& models_path,
       const std::string& tokenizer_path,
       const float temperature,
-      const int eval_mode);
+      const int eval_mode,
+      const bool gen_etdump);

   struct Stats {
     // Scaling factor for timestamps - in this case, we use ms.
@@ -69,6 +72,7 @@ class Runner {
   void stop();
   std::vector<executorch::runtime::Result<executorch::runtime::MethodMeta>>
   get_methods_meta(std::string& method_name);
+  void gen_etdump_data();

  private:
   template <typename T>
@@ -93,6 +97,11 @@ class Runner {
   float temperature_;
   std::unique_ptr<executorch::extension::llm::Tokenizer> tokenizer_;
   std::unique_ptr<executorch::extension::llm::Sampler> sampler_;
+  torch::executor::ETDumpGen* prefill_dump_ = nullptr;
+  torch::executor::ETDumpGen* decode_dump_ = nullptr;
+  bool gen_etdump_ = false;
+  std::string prefill_etdump_path_;
+  std::string decode_etdump_path_;
   Stats stats_;
   std::unique_ptr<Memory> io_mem_;
   EvalMode eval_mode_;
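
A design note on the new members (not a change the commit makes): the tracers are held as raw pointers that are allocated with new in the constructor and never freed. One alternative sketch keeps ownership in the header via std::unique_ptr while still passing raw pointers (.get()) to load_method; the class name below is hypothetical.

#include <memory>
#include <string>

#include <executorch/devtools/etdump/etdump_flatcc.h>

// Sketch of alternative member declarations for the Runner class: unique_ptr
// frees the tracers automatically when the runner is destroyed, and .get()
// still yields the raw EventTracer* that Module::load_method expects.
class RunnerEtdumpMembersSketch {
 private:
  std::unique_ptr<torch::executor::ETDumpGen> prefill_dump_;
  std::unique_ptr<torch::executor::ETDumpGen> decode_dump_;
  bool gen_etdump_ = false;
  std::string prefill_etdump_path_;
  std::string decode_etdump_path_;
};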

examples/qualcomm/oss_scripts/llama3_2/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ def define_common_targets():
         "//executorch/extension/llm/tokenizer:bpe_tokenizer",
         "//executorch/extension/evalue_util:print_evalue",
         "//executorch/backends/qualcomm/runtime:runtime",
+        "//executorch/devtools/etdump:etdump_flatcc",
     ],
     external_deps = [
         "gflags",
