@@ -41,7 +41,8 @@ Runner::Runner(
     const std::vector<std::string>& models_path,
     const std::string& tokenizer_path,
     const float temperature,
-    const int eval_mode)
+    const int eval_mode,
+    const bool gen_etdump)
     : n_bos_(1),
       n_eos_(1),
       tokenizer_path_(tokenizer_path),
@@ -54,6 +55,28 @@ Runner::Runner(
   }
   ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str());
   ET_LOG(Info, "eval mode=%d", eval_mode);
+  if (gen_etdump) {
+    gen_etdump_ = true;
+    switch (eval_mode) {
+      case EvalMode::kPrefill:
+        prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        break;
+      case EvalMode::kKVCached:
+        decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        break;
+      case EvalMode::kHybrid:
+        prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        break;
+      default:
+        ET_CHECK_MSG(false, "Unsupported eval mode");
+        break;
+    }
+    std::string etdump_dir =
+        models_path[0].substr(0, models_path[0].find_last_of("/\\") + 1);
+    prefill_etdump_path_ = etdump_dir + "prefill_etdump.etdp";
+    decode_etdump_path_ = etdump_dir + "decode_etdump.etdp";
+  }
 }
 
 bool Runner::is_loaded() const {
@@ -91,9 +114,17 @@ Error Runner::load() {
 
   for (std::shared_ptr<Module>& module : modules_) {
     if (!prefill_forward_name_.empty()) {
+      if (gen_etdump_) {
+        ET_CHECK_OK_OR_RETURN_ERROR(
+            module->load_method(prefill_forward_name_, prefill_dump_.get()));
+      }
       ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_));
     }
     if (!kv_forward_name_.empty()) {
+      if (gen_etdump_) {
+        ET_CHECK_OK_OR_RETURN_ERROR(
+            module->load_method(kv_forward_name_, decode_dump_.get()));
+      }
       ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_));
     }
   }
@@ -395,6 +426,8 @@ Error Runner::generate(
 
   stats_.num_prompt_tokens = num_prompt_tokens;
   stats_.num_generated_tokens = pos - num_prompt_tokens;
+  if (gen_etdump_)
+    gen_etdump_data();
   printReport(stats_);
   if (stats_callback) {
     stats_callback(stats_);
@@ -403,6 +436,22 @@ Error Runner::generate(
   return Error::Ok;
 }
 
+void Runner::gen_etdump_data() {
+  // dump the prefill and decode etdump data
+  if (prefill_dump_.get() != nullptr) {
+    torch::executor::etdump_result result = prefill_dump_->get_etdump_data();
+    FILE* ptr = fopen(prefill_etdump_path_.c_str(), "w+");
+    fwrite(result.buf, 1, result.size, ptr);
+    fclose(ptr);
+  }
+  if (decode_dump_.get() != nullptr) {
+    torch::executor::etdump_result result = decode_dump_->get_etdump_data();
+    FILE* ptr = fopen(decode_etdump_path_.c_str(), "w+");
+    fwrite(result.buf, 1, result.size, ptr);
+    fclose(ptr);
+  }
+}
+
 namespace {
 void printReport(const Runner::Stats& stats) {
   printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());
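
For reference, a caller-side sketch of how the new flag is meant to be exercised. This is a minimal sketch, not part of the commit: the header path, the omitted namespace qualification, the eval_mode integer value, the file locations, and the generate() signature are all assumptions.

// Hypothetical usage: construct the Runner with ETDump collection enabled.
// Paths, the eval_mode value, and generate()'s arguments are placeholders.
#include <string>
#include <vector>

#include "runner.h" // assumed header for the Runner above

int main() {
  std::vector<std::string> models_path = {"/data/local/tmp/llama/model.pte"};
  std::string tokenizer_path = "/data/local/tmp/llama/tokenizer.bin";

  // With gen_etdump=true, the constructor creates one ETDumpGen per method
  // implied by eval_mode, and load() attaches each one to its load_method().
  Runner runner(
      models_path,
      tokenizer_path,
      /*temperature=*/0.8f,
      /*eval_mode=*/2, // assumed to map to EvalMode::kHybrid
      /*gen_etdump=*/true);

  // Once decoding finishes, generate() calls gen_etdump_data(), which writes
  // prefill_etdump.etdp and decode_etdump.etdp into the directory holding
  // models_path[0] (here /data/local/tmp/llama/).
  runner.generate("Tell me a story", /*seq_len=*/128);
  return 0;
}

One consequence of deriving etdump_dir from models_path[0] is that the .etdp files always land next to the first model file, so on-device runs need that directory to be writable.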
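
gen_etdump_data() assumes fopen() succeeds and that the generator actually collected data. A more defensive variant of the file-writing step could look like this sketch; the helper name write_etdump is made up, and the free(result.buf) follows the pattern used by ExecuTorch's llama examples, which is assumed to apply here as well.

#include <cstdio>
#include <cstdlib>
#include <string>

// Hypothetical helper: write one ETDump buffer to disk, skipping empty
// results and failed opens instead of handing nullptr to fwrite().
static void write_etdump(
    torch::executor::ETDumpGen* dump_gen,
    const std::string& path) {
  if (dump_gen == nullptr) {
    return;
  }
  torch::executor::etdump_result result = dump_gen->get_etdump_data();
  if (result.buf == nullptr || result.size == 0) {
    return; // nothing was traced for this method
  }
  FILE* f = fopen(path.c_str(), "wb"); // binary mode instead of "w+"
  if (f != nullptr) {
    fwrite(result.buf, 1, result.size, f);
    fclose(f);
  }
  free(result.buf); // examples in the ExecuTorch tree free this buffer
}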