Skip to content

Commit 1123e3f

Browse files
billmguo authored and facebook-github-bot committed
support ET dump for llama3 runner (pytorch#7507)
Summary: Support ET dump for the llama3 runner to make it easier to understand performance regression issues. Reviewed By: Andriyluck, limintang Differential Revision: D67656207
1 parent 6c9b9b6 commit 1123e3f

File tree

5 files changed

+68
-5
lines changed

5 files changed

+68
-5
lines changed

examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# model sharding with custom op
8-
set(CUSTOM_OP_SRCS_FILE
8+
set(CUSTOM_OP_SRCS_FILE
99
"${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp"
1010
)
1111
add_library(custom_ops ${CUSTOM_OP_SRCS_FILE})
@@ -45,7 +45,7 @@ list(
4545
# build qnn llama3.2 1b runner
4646
add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs})
4747
target_include_directories(
48-
qnn_llama3_2_runner PUBLIC ${_common_include_directories}
48+
qnn_llama3_2_runner PUBLIC ${_common_include_directories} ${EXECUTORCH_SOURCE_DIR}/devtools/etdump
4949
)
5050

5151
target_link_libraries(
@@ -58,6 +58,7 @@ target_link_libraries(
5858
gflags
5959
re2::re2
6060
custom_ops
61+
etdump
6162
)
6263
target_compile_options(
6364
qnn_llama3_2_runner PUBLIC ${_common_compile_options}

examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ DEFINE_int32(
4949
0,
5050
"0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)");
5151

52+
DEFINE_bool(
53+
gen_etdump,
54+
false,
55+
"false: Disable ET dump/ True: Enable ET dump (default: false)");
56+
57+
5258
int main(int argc, char** argv) {
5359
gflags::ParseCommandLineFlags(&argc, &argv, true);
5460

@@ -57,7 +63,8 @@ int main(int argc, char** argv) {
5763
{FLAGS_model_path},
5864
FLAGS_tokenizer_path.c_str(),
5965
FLAGS_temperature,
60-
FLAGS_eval_mode);
66+
FLAGS_eval_mode,
67+
FLAGS_gen_etdump);
6168
std::vector<char> buf;
6269
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
6370
std::ofstream fout(FLAGS_output_path.c_str());

examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ Runner::Runner(
4141
const std::vector<std::string>& models_path,
4242
const std::string& tokenizer_path,
4343
const float temperature,
44-
const int eval_mode)
44+
const int eval_mode,
45+
const bool gen_etdump)
4546
: n_bos_(1),
4647
n_eos_(1),
4748
tokenizer_path_(tokenizer_path),
@@ -54,6 +55,27 @@ Runner::Runner(
5455
}
5556
ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str());
5657
ET_LOG(Info, "eval mode=%d", eval_mode);
58+
if (gen_etdump) {
59+
gen_etdump_ = true;
60+
switch(eval_mode) {
61+
case EvalMode::kPrefill:
62+
prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
63+
break;
64+
case EvalMode::kKVCached:
65+
decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
66+
break;
67+
case EvalMode::kHybrid:
68+
prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
69+
decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
70+
break;
71+
default:
72+
ET_CHECK_MSG(false, "Unsupported eval mode");
73+
break;
74+
}
75+
std::string etdump_dir = models_path[0].substr(0, models_path[0].find_last_of("/\\") + 1);
76+
prefill_etdump_path_ = etdump_dir + "prefill_etdump.etdp";
77+
decode_etdump_path_ = etdump_dir + "decode_etdump.etdp";
78+
}
5779
}
5880

5981
bool Runner::is_loaded() const {
@@ -91,9 +113,15 @@ Error Runner::load() {
91113

92114
for (std::shared_ptr<Module>& module : modules_) {
93115
if (!prefill_forward_name_.empty()) {
116+
if (gen_etdump_) {
117+
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_, prefill_dump_.get()));
118+
}
94119
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_));
95120
}
96121
if (!kv_forward_name_.empty()) {
122+
if (gen_etdump_) {
123+
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_, decode_dump_.get()));
124+
}
97125
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_));
98126
}
99127
}
@@ -395,6 +423,8 @@ Error Runner::generate(
395423

396424
stats_.num_prompt_tokens = num_prompt_tokens;
397425
stats_.num_generated_tokens = pos - num_prompt_tokens;
426+
if (gen_etdump_)
427+
gen_etdump_data();
398428
printReport(stats_);
399429
if (stats_callback) {
400430
stats_callback(stats_);
@@ -403,6 +433,22 @@ Error Runner::generate(
403433
return Error::Ok;
404434
}
405435

436+
void Runner::gen_etdump_data(){
437+
//dump the prefill and decode etdump data
438+
if (prefill_dump_.get() != nullptr) {
439+
torch::executor::etdump_result result = prefill_dump_->get_etdump_data();
440+
FILE* ptr = fopen(prefill_etdump_path_.c_str(), "w+");
441+
fwrite(result.buf, 1, result.size, ptr);
442+
fclose(ptr);
443+
}
444+
if (decode_dump_.get() != nullptr) {
445+
torch::executor::etdump_result result = decode_dump_->get_etdump_data();
446+
FILE* ptr = fopen(decode_etdump_path_.c_str(), "w+");
447+
fwrite(result.buf, 1, result.size, ptr);
448+
fclose(ptr);
449+
}
450+
}
451+
406452
namespace {
407453
void printReport(const Runner::Stats& stats) {
408454
printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());

examples/qualcomm/oss_scripts/llama3_2/runner/runner.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <string>
1818
#include <unordered_map>
1919

20+
#include <executorch/devtools/etdump/etdump_flatcc.h>
2021
#include <executorch/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h>
2122
#include <executorch/extension/llm/sampler/sampler.h>
2223
#include <executorch/extension/llm/tokenizer/tokenizer.h>
@@ -30,7 +31,8 @@ class Runner {
3031
const std::vector<std::string>& models_path,
3132
const std::string& tokenizer_path,
3233
const float temperature,
33-
const int eval_mode);
34+
const int eval_mode,
35+
const bool gen_etdump);
3436

3537
struct Stats {
3638
// Scaling factor for timestamps - in this case, we use ms.
@@ -69,6 +71,7 @@ class Runner {
6971
void stop();
7072
std::vector<executorch::runtime::Result<executorch::runtime::MethodMeta>>
7173
get_methods_meta(std::string& method_name);
74+
void gen_etdump_data();
7275

7376
private:
7477
template <typename T>
@@ -93,6 +96,11 @@ class Runner {
9396
float temperature_;
9497
std::unique_ptr<executorch::extension::llm::Tokenizer> tokenizer_;
9598
std::unique_ptr<executorch::extension::llm::Sampler> sampler_;
99+
std::unique_ptr<torch::executor::ETDumpGen> prefill_dump_;
100+
std::unique_ptr<torch::executor::ETDumpGen> decode_dump_;
101+
bool gen_etdump_ = false;
102+
std::string prefill_etdump_path_;
103+
std::string decode_etdump_path_;
96104
Stats stats_;
97105
std::unique_ptr<Memory> io_mem_;
98106
EvalMode eval_mode_;

examples/qualcomm/oss_scripts/llama3_2/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def define_common_targets():
2828
"//executorch/extension/llm/tokenizer:bpe_tokenizer",
2929
"//executorch/extension/evalue_util:print_evalue",
3030
"//executorch/backends/qualcomm/runtime:runtime",
31+
"//executorch/devtools/etdump:etdump_flatcc",
3132
],
3233
external_deps = [
3334
"gflags",

0 commit comments

Comments (0)