Skip to content

Commit 25db0b5

Browse files
billmguo authored and facebook-github-bot committed
support ET dump for llama3 runner (#7507)
Summary: Support ET dump for the llama3 runner to make it easier to understand performance regression issues. Reviewed By: Andriyluck, limintang Differential Revision: D67656207
1 parent 54f0786 commit 25db0b5

File tree

5 files changed

+70
-5
lines changed

5 files changed

+70
-5
lines changed

examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# model sharding with custom op
8-
set(CUSTOM_OP_SRCS_FILE
8+
set(CUSTOM_OP_SRCS_FILE
99
"${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp"
1010
)
1111
add_library(custom_ops ${CUSTOM_OP_SRCS_FILE})
@@ -45,7 +45,7 @@ list(
4545
# build qnn llama3.2 1b runner
4646
add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs})
4747
target_include_directories(
48-
qnn_llama3_2_runner PUBLIC ${_common_include_directories}
48+
qnn_llama3_2_runner PUBLIC ${_common_include_directories} ${EXECUTORCH_SOURCE_DIR}/devtools/etdump
4949
)
5050

5151
target_link_libraries(
@@ -58,6 +58,7 @@ target_link_libraries(
5858
gflags
5959
re2::re2
6060
custom_ops
61+
etdump
6162
)
6263
target_compile_options(
6364
qnn_llama3_2_runner PUBLIC ${_common_compile_options}

examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ DEFINE_int32(
4949
0,
5050
"0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)");
5151

52+
DEFINE_bool(
53+
gen_etdump,
54+
false,
55+
"false: Disable ET dump/ True: Enable ET dump (default: false)");
56+
5257
int main(int argc, char** argv) {
5358
gflags::ParseCommandLineFlags(&argc, &argv, true);
5459

@@ -57,7 +62,8 @@ int main(int argc, char** argv) {
5762
{FLAGS_model_path},
5863
FLAGS_tokenizer_path.c_str(),
5964
FLAGS_temperature,
60-
FLAGS_eval_mode);
65+
FLAGS_eval_mode,
66+
FLAGS_gen_etdump);
6167
std::vector<char> buf;
6268
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
6369
std::ofstream fout(FLAGS_output_path.c_str());

examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ Runner::Runner(
4141
const std::vector<std::string>& models_path,
4242
const std::string& tokenizer_path,
4343
const float temperature,
44-
const int eval_mode)
44+
const int eval_mode,
45+
const bool gen_etdump)
4546
: n_bos_(1),
4647
n_eos_(1),
4748
tokenizer_path_(tokenizer_path),
@@ -54,6 +55,28 @@ Runner::Runner(
5455
}
5556
ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str());
5657
ET_LOG(Info, "eval mode=%d", eval_mode);
58+
if (gen_etdump) {
59+
gen_etdump_ = true;
60+
switch (eval_mode) {
61+
case EvalMode::kPrefill:
62+
prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
63+
break;
64+
case EvalMode::kKVCached:
65+
decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
66+
break;
67+
case EvalMode::kHybrid:
68+
prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
69+
decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
70+
break;
71+
default:
72+
ET_CHECK_MSG(false, "Unsupported eval mode");
73+
break;
74+
}
75+
std::string etdump_dir =
76+
models_path[0].substr(0, models_path[0].find_last_of("/\\") + 1);
77+
prefill_etdump_path_ = etdump_dir + "prefill_etdump.etdp";
78+
decode_etdump_path_ = etdump_dir + "decode_etdump.etdp";
79+
}
5780
}
5881

5982
bool Runner::is_loaded() const {
@@ -91,9 +114,17 @@ Error Runner::load() {
91114

92115
for (std::shared_ptr<Module>& module : modules_) {
93116
if (!prefill_forward_name_.empty()) {
117+
if (gen_etdump_) {
118+
ET_CHECK_OK_OR_RETURN_ERROR(
119+
module->load_method(prefill_forward_name_, prefill_dump_.get()));
120+
}
94121
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_));
95122
}
96123
if (!kv_forward_name_.empty()) {
124+
if (gen_etdump_) {
125+
ET_CHECK_OK_OR_RETURN_ERROR(
126+
module->load_method(kv_forward_name_, decode_dump_.get()));
127+
}
97128
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_));
98129
}
99130
}
@@ -395,6 +426,8 @@ Error Runner::generate(
395426

396427
stats_.num_prompt_tokens = num_prompt_tokens;
397428
stats_.num_generated_tokens = pos - num_prompt_tokens;
429+
if (gen_etdump_)
430+
gen_etdump_data();
398431
printReport(stats_);
399432
if (stats_callback) {
400433
stats_callback(stats_);
@@ -403,6 +436,22 @@ Error Runner::generate(
403436
return Error::Ok;
404437
}
405438

439+
void Runner::gen_etdump_data() {
440+
// dump the prefill and decode etdump data
441+
if (prefill_dump_.get() != nullptr) {
442+
torch::executor::etdump_result result = prefill_dump_->get_etdump_data();
443+
FILE* ptr = fopen(prefill_etdump_path_.c_str(), "w+");
444+
fwrite(result.buf, 1, result.size, ptr);
445+
fclose(ptr);
446+
}
447+
if (decode_dump_.get() != nullptr) {
448+
torch::executor::etdump_result result = decode_dump_->get_etdump_data();
449+
FILE* ptr = fopen(decode_etdump_path_.c_str(), "w+");
450+
fwrite(result.buf, 1, result.size, ptr);
451+
fclose(ptr);
452+
}
453+
}
454+
406455
namespace {
407456
void printReport(const Runner::Stats& stats) {
408457
printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());

examples/qualcomm/oss_scripts/llama3_2/runner/runner.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <string>
1818
#include <unordered_map>
1919

20+
#include <executorch/devtools/etdump/etdump_flatcc.h>
2021
#include <executorch/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h>
2122
#include <executorch/extension/llm/sampler/sampler.h>
2223
#include <executorch/extension/llm/tokenizer/tokenizer.h>
@@ -30,7 +31,8 @@ class Runner {
3031
const std::vector<std::string>& models_path,
3132
const std::string& tokenizer_path,
3233
const float temperature,
33-
const int eval_mode);
34+
const int eval_mode,
35+
const bool gen_etdump);
3436

3537
struct Stats {
3638
// Scaling factor for timestamps - in this case, we use ms.
@@ -69,6 +71,7 @@ class Runner {
6971
void stop();
7072
std::vector<executorch::runtime::Result<executorch::runtime::MethodMeta>>
7173
get_methods_meta(std::string& method_name);
74+
void gen_etdump_data();
7275

7376
private:
7477
template <typename T>
@@ -93,6 +96,11 @@ class Runner {
9396
float temperature_;
9497
std::unique_ptr<executorch::extension::llm::Tokenizer> tokenizer_;
9598
std::unique_ptr<executorch::extension::llm::Sampler> sampler_;
99+
std::unique_ptr<torch::executor::ETDumpGen> prefill_dump_;
100+
std::unique_ptr<torch::executor::ETDumpGen> decode_dump_;
101+
bool gen_etdump_ = false;
102+
std::string prefill_etdump_path_;
103+
std::string decode_etdump_path_;
96104
Stats stats_;
97105
std::unique_ptr<Memory> io_mem_;
98106
EvalMode eval_mode_;

examples/qualcomm/oss_scripts/llama3_2/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def define_common_targets():
2828
"//executorch/extension/llm/tokenizer:bpe_tokenizer",
2929
"//executorch/extension/evalue_util:print_evalue",
3030
"//executorch/backends/qualcomm/runtime:runtime",
31+
"//executorch/devtools/etdump:etdump_flatcc",
3132
],
3233
external_deps = [
3334
"gflags",

0 commit comments

Comments (0)