@@ -40,6 +40,8 @@ limitations under the License.
4040#include " tflite/c/common.h"
4141#include " tflite/interpreter.h"
4242#include " tflite/profiling/model_runtime_info.h"
43+ #include " tflite/profiling/profile_buffer.h"
44+ #include " tflite/profiling/profile_summarizer.h"
4345#include " tflite/tools/benchmark/benchmark_model.h"
4446#include " tflite/tools/benchmark/benchmark_params.h"
4547#include " tflite/tools/benchmark/proto/benchmark_result.pb.h"
@@ -54,8 +56,13 @@ using ::tflite::tools::benchmark::BenchmarkResult;
5456class BenchmarkLoggingListener : public ::tflite::benchmark::BenchmarkListener {
5557 private:
5658 std::string result_file_path_ = " " ;
59+ tflite::profiling::ProfileSummarizer* run_summarizer_;
5760
5861 public:
62+ explicit BenchmarkLoggingListener (
63+ tflite::profiling::ProfileSummarizer* run_summarizer)
64+ : run_summarizer_(run_summarizer) {}
65+
5966 void OnBenchmarkStart (
6067 const ::tflite::benchmark::BenchmarkParams& params) override {
6168 if (!params.Get <std::string>(" result_file_path" ).empty ()) {
@@ -165,6 +172,11 @@ class BenchmarkLoggingListener : public ::tflite::benchmark::BenchmarkListener {
165172 result_file_path_.c_str ());
166173 }
167174 }
175+
176+ if (run_summarizer_) {
177+ LITERT_LOG (LITERT_INFO, " \n %s" ,
178+ run_summarizer_->GetOutputString ().c_str ());
179+ }
168180 }
169181};
170182
@@ -206,8 +218,7 @@ using ::tflite::utils::InputTensorData;
206218class BenchmarkLiteRtModel : public BenchmarkModel {
207219 public:
208220 explicit BenchmarkLiteRtModel (BenchmarkParams params = DefaultParams())
209- : BenchmarkModel(std::move(params)), log_output_() {
210- AddListener (&log_output_);
221+ : BenchmarkModel(std::move(params)) {
211222 model_runtime_info_listener_ = nullptr ;
212223 }
213224 ~BenchmarkLiteRtModel () override = default ;
@@ -336,10 +347,28 @@ class BenchmarkLiteRtModel : public BenchmarkModel {
336347 TfLiteStatus ResetInputsAndOutputs () override {
337348 if (profiler_) {
338349 profiler_.StopProfiling ();
350+
339351 auto events = profiler_.GetEvents ();
352+ std::vector<std::unique_ptr<tflite::profiling::ProfileEvent>>
353+ tflite_events;
354+ std::vector<const tflite::profiling::ProfileEvent*> tflite_ptr_events;
355+ tflite_events.reserve (events->size ());
356+ tflite_ptr_events.reserve (events->size ());
340357 for (const auto & event : *events) {
341- LITERT_LOG (LITERT_INFO, " Event: %s" , event.tag );
358+ auto tflite_event = std::make_unique<tflite::profiling::ProfileEvent>();
359+ // Refer litert/litert/runtime/profiler.cc
360+ tflite_event->tag = event.tag ;
361+ tflite_event->begin_timestamp_us = event.start_timestamp_us ;
362+ tflite_event->elapsed_time = event.elapsed_time_us ;
363+ tflite_event->event_type = event.event_type ;
364+ tflite_event->event_metadata = event.event_metadata1 ;
365+ tflite_event->extra_event_metadata = event.event_metadata2 ;
366+ tflite_event->begin_mem_usage = event.begin_mem_usage ;
367+ tflite_event->end_mem_usage = event.end_mem_usage ;
368+ tflite_ptr_events.push_back (tflite_event.get ());
369+ tflite_events.push_back (std::move (tflite_event));
342370 }
371+ run_summarizer_->ProcessProfiles (tflite_ptr_events, *interpreter_);
343372 profiler_.Reset ();
344373 profiler_.StartProfiling ();
345374 }
@@ -394,8 +423,12 @@ class BenchmarkLiteRtModel : public BenchmarkModel {
394423 std::unique_ptr<std::vector<litert::TensorBuffer>> input_buffers_;
395424 std::unique_ptr<std::vector<litert::TensorBuffer>> output_buffers_;
396425 litert::Profiler profiler_;
397- BenchmarkLoggingListener log_output_;
426+ std::unique_ptr< BenchmarkLoggingListener> log_output_;
398427 std::unique_ptr<ModelRuntimeInfoListener> model_runtime_info_listener_;
428+
429+ // TFLite Interpreter is needed for run_summarizer_
430+ ::tflite::Interpreter* interpreter_ = nullptr ;
431+ std::unique_ptr<tflite::profiling::ProfileSummarizer> run_summarizer_;
399432};
400433
401434} // namespace benchmark
0 commit comments