@@ -41,7 +41,8 @@ Runner::Runner(
     const std::vector<std::string>& models_path,
     const std::string& tokenizer_path,
     const float temperature,
-    const int eval_mode)
+    const int eval_mode,
+    const bool gen_etdump)
     : n_bos_(1),
       n_eos_(1),
       tokenizer_path_(tokenizer_path),
@@ -54,6 +55,27 @@ Runner::Runner(
   }
   ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str());
   ET_LOG(Info, "eval mode=%d", eval_mode);
+  if (gen_etdump) {
+    gen_etdump_ = true;
+    switch (eval_mode) {
+      case EvalMode::kPrefill:
+        prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        break;
+      case EvalMode::kKVCached:
+        decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        break;
+      case EvalMode::kHybrid:
+        prefill_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        decode_dump_ = std::make_unique<torch::executor::ETDumpGen>();
+        break;
+      default:
+        ET_CHECK_MSG(false, "Unsupported eval mode");
+        break;
+    }
+    std::string etdump_dir = models_path[0].substr(0, models_path[0].find_last_of("/\\") + 1);
+    prefill_etdump_path_ = etdump_dir + "prefill_etdump.etdp";
+    decode_etdump_path_ = etdump_dir + "decode_etdump.etdp";
+  }
 }

 bool Runner::is_loaded() const {
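Note on the path handling above: the etdump output directory is taken from the directory of the first model file, so the .etdp files land next to the .pte. A minimal, self-contained sketch of that derivation (the model path below is hypothetical, not from this patch):

#include <iostream>
#include <string>

int main() {
  // Hypothetical model path; the runner uses models_path[0].
  std::string model_path = "/data/local/tmp/llama/model.pte";
  // Keep everything up to and including the last '/' or '\'.
  std::string etdump_dir =
      model_path.substr(0, model_path.find_last_of("/\\") + 1);
  std::cout << etdump_dir + "prefill_etdump.etdp" << "\n";
  // Prints: /data/local/tmp/llama/prefill_etdump.etdp
  return 0;
}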
@@ -91,9 +113,15 @@ Error Runner::load() {

   for (std::shared_ptr<Module>& module : modules_) {
     if (!prefill_forward_name_.empty()) {
+      if (gen_etdump_) {
+        ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_, prefill_dump_.get()));
+      }
       ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_));
     }
     if (!kv_forward_name_.empty()) {
+      if (gen_etdump_) {
+        ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_, decode_dump_.get()));
+      }
       ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_));
     }
   }
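The wiring above relies on Module::load_method taking an optional EventTracer*; passing the ETDumpGen at load time is what makes later execute() calls on that method emit trace events. A condensed sketch of the same pattern (assuming the pre-1.0 torch::executor namespace this patch uses, and that load_method returns Ok for an already-loaded method, which would make the unconditional call after the guarded one above harmless):

// Sketch only: load a method with an attached ETDumpGen so that
// subsequent execute() calls on it are traced.
Error load_traced(
    std::shared_ptr<Module>& module,
    const std::string& method_name,
    torch::executor::ETDumpGen* tracer) {
  // A non-null EventTracer here is what enables tracing for this method.
  ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(method_name, tracer));
  return Error::Ok;
}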
@@ -395,6 +423,8 @@ Error Runner::generate(

   stats_.num_prompt_tokens = num_prompt_tokens;
   stats_.num_generated_tokens = pos - num_prompt_tokens;
+  if (gen_etdump_)
+    gen_etdump_data();
   printReport(stats_);
   if (stats_callback) {
     stats_callback(stats_);
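With the new constructor parameter, a caller opts in to tracing per run. An illustrative construction (paths and values are hypothetical; the flag plumbing in the runner's entry point is not part of this diff):

// EvalMode::kHybrid traces both graphs, producing prefill_etdump.etdp
// and decode_etdump.etdp next to the model file.
Runner runner(
    {"/data/local/tmp/llama/model.pte"},    // models_path (hypothetical)
    "/data/local/tmp/llama/tokenizer.bin",  // tokenizer_path (hypothetical)
    0.8f,                                   // temperature
    EvalMode::kHybrid,
    /*gen_etdump=*/true);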
@@ -403,6 +433,22 @@ Error Runner::generate(
   return Error::Ok;
 }

+void Runner::gen_etdump_data() {
+  // dump the prefill and decode etdump data
+  if (prefill_dump_.get() != nullptr) {
+    torch::executor::etdump_result result = prefill_dump_->get_etdump_data();
+    FILE* ptr = fopen(prefill_etdump_path_.c_str(), "w+");
+    fwrite(result.buf, 1, result.size, ptr);
+    fclose(ptr);
+  }
+  if (decode_dump_.get() != nullptr) {
+    torch::executor::etdump_result result = decode_dump_->get_etdump_data();
+    FILE* ptr = fopen(decode_etdump_path_.c_str(), "w+");
+    fwrite(result.buf, 1, result.size, ptr);
+    fclose(ptr);
+  }
+}
+
 namespace {
 void printReport(const Runner::Stats& stats) {
   printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());
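One caveat on gen_etdump_data() as written: fopen() is not checked, the stream is opened in text mode, and the buffer returned by get_etdump_data() is never released. A hardened sketch of the same dump step (the log-and-skip error policy is an assumption; the free() mirrors ExecuTorch's executor_runner example, which frees the returned buffer after writing):

// Sketch: same API calls as the patch, with error checks, binary
// mode ("wb"; the etdump payload is a binary flatbuffer), and a
// free() of the heap-allocated result buffer.
static void write_etdump(
    torch::executor::ETDumpGen* dump_gen, const std::string& path) {
  if (dump_gen == nullptr) {
    return;
  }
  torch::executor::etdump_result result = dump_gen->get_etdump_data();
  if (result.buf == nullptr || result.size == 0) {
    return;
  }
  FILE* f = fopen(path.c_str(), "wb");
  if (f == nullptr) {
    ET_LOG(Error, "failed to open %s for writing", path.c_str());
    free(result.buf);
    return;
  }
  fwrite(result.buf, 1, result.size, f);
  fclose(f);
  free(result.buf);
}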