Skip to content

Commit b8ef71e

Browse files
billmguo and facebook-github-bot
authored and committed
support ET dump for llama3 runner (#7507)
Summary: Pull Request resolved: #7507 Support ET dump for the llama3 runner to make it easier to understand performance regression issues Reviewed By: Andriyluck, limintang Differential Revision: D67656207
1 parent 1bac885 commit b8ef71e

File tree

5 files changed

+71
-5
lines changed

5 files changed

+71
-5
lines changed

examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# model sharding with custom op
8-
set(CUSTOM_OP_SRCS_FILE
8+
set(CUSTOM_OP_SRCS_FILE
99
"${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp"
1010
)
1111
add_library(custom_ops ${CUSTOM_OP_SRCS_FILE})
@@ -45,7 +45,7 @@ list(
4545
# build qnn llama3.2 1b runner
4646
add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs})
4747
target_include_directories(
48-
qnn_llama3_2_runner PUBLIC ${_common_include_directories}
48+
qnn_llama3_2_runner PUBLIC ${_common_include_directories} ${EXECUTORCH_SOURCE_DIR}/devtools/etdump
4949
)
5050

5151
target_link_libraries(
@@ -58,6 +58,8 @@ target_link_libraries(
5858
gflags
5959
re2::re2
6060
custom_ops
61+
etdump
62+
${FLATCCRT_LIB}
6163
)
6264
target_compile_options(
6365
qnn_llama3_2_runner PUBLIC ${_common_compile_options}

examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ DEFINE_int32(
5151
DEFINE_double(logits_scale, 0.0, "Logits scale");
5252
DEFINE_int32(logits_offset, 0, "Logits offset");
5353

54+
DEFINE_bool(
55+
gen_etdump,
56+
false,
57+
"false: Disable ET dump/ True: Enable ET dump (default: false)");
58+
5459
int main(int argc, char** argv) {
5560
gflags::ParseCommandLineFlags(&argc, &argv, true);
5661

@@ -61,7 +66,8 @@ int main(int argc, char** argv) {
6166
FLAGS_logits_scale,
6267
FLAGS_logits_offset,
6368
FLAGS_temperature,
64-
FLAGS_eval_mode);
69+
FLAGS_eval_mode,
70+
FLAGS_gen_etdump);
6571
std::vector<char> buf;
6672
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
6773
std::ofstream fout(FLAGS_output_path.c_str());

examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ Runner::Runner(
4343
const float logits_scale,
4444
const int32_t logits_offset,
4545
const float temperature,
46-
const int eval_mode)
46+
const int eval_mode,
47+
const bool gen_etdump)
4748
: n_bos_(1),
4849
n_eos_(1),
4950
tokenizer_path_(tokenizer_path),
@@ -58,6 +59,28 @@ Runner::Runner(
5859
}
5960
ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str());
6061
ET_LOG(Info, "eval mode=%d", eval_mode);
62+
if (gen_etdump) {
63+
gen_etdump_ = true;
64+
switch (eval_mode) {
65+
case EvalMode::kPrefill:
66+
prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
67+
break;
68+
case EvalMode::kKVCached:
69+
decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
70+
break;
71+
case EvalMode::kHybrid:
72+
prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
73+
decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
74+
break;
75+
default:
76+
ET_CHECK_MSG(false, "Unsupported eval mode");
77+
break;
78+
}
79+
std::string etdump_dir =
80+
models_path[0].substr(0, models_path[0].find_last_of("/\\") + 1);
81+
prefill_etdump_path_ = etdump_dir + "prefill_etdump.etdp";
82+
decode_etdump_path_ = etdump_dir + "decode_etdump.etdp";
83+
}
6184
}
6285

6386
bool Runner::is_loaded() const {
@@ -95,9 +118,17 @@ Error Runner::load() {
95118

96119
for (std::shared_ptr<Module>& module : modules_) {
97120
if (!prefill_forward_name_.empty()) {
121+
if (gen_etdump_) {
122+
ET_CHECK_OK_OR_RETURN_ERROR(
123+
module->load_method(prefill_forward_name_, prefill_dump_.get()));
124+
}
98125
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_));
99126
}
100127
if (!kv_forward_name_.empty()) {
128+
if (gen_etdump_) {
129+
ET_CHECK_OK_OR_RETURN_ERROR(
130+
module->load_method(kv_forward_name_, decode_dump_.get()));
131+
}
101132
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_));
102133
}
103134
}
@@ -424,6 +455,8 @@ Error Runner::generate(
424455

425456
stats_.num_prompt_tokens = num_prompt_tokens;
426457
stats_.num_generated_tokens = pos - num_prompt_tokens;
458+
if (gen_etdump_)
459+
gen_etdump_data();
427460
printReport(stats_);
428461
if (stats_callback) {
429462
stats_callback(stats_);
@@ -432,6 +465,22 @@ Error Runner::generate(
432465
return Error::Ok;
433466
}
434467

468+
void Runner::gen_etdump_data() {
469+
// dump the prefill and decode etdump data
470+
if (prefill_dump_.get() != nullptr) {
471+
torch::executor::etdump_result result = prefill_dump_->get_etdump_data();
472+
FILE* ptr = fopen(prefill_etdump_path_.c_str(), "w+");
473+
fwrite(result.buf, 1, result.size, ptr);
474+
fclose(ptr);
475+
}
476+
if (decode_dump_.get() != nullptr) {
477+
torch::executor::etdump_result result = decode_dump_->get_etdump_data();
478+
FILE* ptr = fopen(decode_etdump_path_.c_str(), "w+");
479+
fwrite(result.buf, 1, result.size, ptr);
480+
fclose(ptr);
481+
}
482+
}
483+
435484
namespace {
436485
void printReport(const Runner::Stats& stats) {
437486
printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());

examples/qualcomm/oss_scripts/llama3_2/runner/runner.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <string>
1818
#include <unordered_map>
1919

20+
#include <executorch/devtools/etdump/etdump_flatcc.h>
2021
#include <executorch/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h>
2122
#include <executorch/extension/llm/sampler/sampler.h>
2223
#include <executorch/extension/llm/tokenizer/tokenizer.h>
@@ -32,7 +33,8 @@ class Runner {
3233
const float logits_scale,
3334
const int32_t logits_offset,
3435
const float temperature,
35-
const int eval_mode);
36+
const int eval_mode,
37+
const bool gen_etdump);
3638

3739
struct Stats {
3840
// Scaling factor for timestamps - in this case, we use ms.
@@ -71,6 +73,7 @@ class Runner {
7173
void stop();
7274
std::vector<executorch::runtime::Result<executorch::runtime::MethodMeta>>
7375
get_methods_meta(std::string& method_name);
76+
void gen_etdump_data();
7477

7578
private:
7679
template <typename T>
@@ -98,6 +101,11 @@ class Runner {
98101
float temperature_;
99102
std::unique_ptr<executorch::extension::llm::Tokenizer> tokenizer_;
100103
std::unique_ptr<executorch::extension::llm::Sampler> sampler_;
104+
std::unique_ptr<torch::executor::ETDumpGen> prefill_dump_;
105+
std::unique_ptr<torch::executor::ETDumpGen> decode_dump_;
106+
bool gen_etdump_ = false;
107+
std::string prefill_etdump_path_;
108+
std::string decode_etdump_path_;
101109
Stats stats_;
102110
std::unique_ptr<Memory> io_mem_;
103111
EvalMode eval_mode_;

examples/qualcomm/oss_scripts/llama3_2/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def define_common_targets():
2828
"//executorch/extension/llm/tokenizer:bpe_tokenizer",
2929
"//executorch/extension/evalue_util:print_evalue",
3030
"//executorch/backends/qualcomm/runtime:runtime",
31+
"//executorch/devtools/etdump:etdump_flatcc",
3132
],
3233
external_deps = [
3334
"gflags",

0 commit comments

Comments
 (0)