Skip to content

Commit a5b1f83

Browse files
terryheo, copybara-github
authored and committed
Enable GPU op profiling when use_profiler is used.
LiteRT-PiperOrigin-RevId: 818772840
1 parent 49a483d commit a5b1f83

File tree

5 files changed

+69
-19
lines changed

5 files changed

+69
-19
lines changed

litert/runtime/profiler.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ void LiteRtProfilerT::EndEvent(uint32_t event_handle) {
7676
}
7777

7878
void LiteRtProfilerT::AddEvent(const char* tag, EventType event_type,
79-
uint64_t elapsed_time, int64_t event_metadata1,
80-
int64_t event_metadata2) {
79+
uint64_t metric, int64_t event_metadata1,
80+
int64_t event_metadata2) {
8181
if (!profiling_enabled_ || !profile_buffer_) {
8282
return;
8383
}
@@ -86,8 +86,8 @@ void LiteRtProfilerT::AddEvent(const char* tag, EventType event_type,
8686
std::string s_tag(tag);
8787
auto [it, inserted] = owned_tags_set_.insert(std::move(s_tag));
8888
const char* owned_tag_ptr = it->c_str();
89-
profile_buffer_->AddEvent(owned_tag_ptr, event_type, elapsed_time,
90-
event_metadata1, event_metadata2);
89+
profile_buffer_->AddEvent(owned_tag_ptr, event_type, metric, event_metadata1,
90+
event_metadata2);
9191
}
9292

9393
void LiteRtProfilerT::StartProfiling() {

litert/runtime/profiler.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,12 @@ class LiteRtProfilerT : public tflite::Profiler {
5353

5454
// tag is copied and owned by the profiler, caller does not need to keep
5555
// the string alive.
56-
void AddEvent(const char* tag, EventType event_type, uint64_t
57-
elapsed_time, int64_t event_metadata1, int64_t event_metadata2) override;
56+
// `metric` field has a different interpretation based on `event_type`.
57+
// e.g. it means elapsed time for [DELEGATE_]OPERATOR_INVOKE_EVENT types,
58+
// and is interpreted as source and status code for TELEMETRY_[DELEGATE_]EVENT
59+
// event types.
60+
void AddEvent(const char* tag, EventType event_type, uint64_t metric,
61+
int64_t event_metadata1, int64_t event_metadata2) override;
5862

5963
// Enables profiling. Events will start being recorded.
6064
void StartProfiling();

litert/tools/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ cc_library(
370370
"//tflite/c:c_api_types",
371371
"//tflite/c:common",
372372
"//tflite/profiling:model_runtime_info",
373+
"//tflite/profiling:profile_buffer",
374+
"//tflite/profiling:profile_summarizer",
373375
"//tflite/tools:command_line_flags",
374376
"//tflite/tools:utils",
375377
"//tflite/tools/benchmark:benchmark_model_lib",

litert/tools/benchmark_litert_model.cc

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ limitations under the License.
3939
#include "tflite/c/c_api_types.h"
4040
#include "tflite/c/common.h"
4141
#include "tflite/interpreter.h"
42+
#include "tflite/profiling/profile_summarizer.h"
4243

4344
namespace litert::benchmark {
4445
namespace {
@@ -89,6 +90,12 @@ Options CreateCompiledModelOptions(const BenchmarkParams& params) {
8990
if (gpu_low_priority) {
9091
gpu_options.SetGpuPriority(kLiteRtGpuPriorityLow);
9192
}
93+
94+
auto use_profiler = params.Get<bool>("use_profiler");
95+
if (use_profiler) {
96+
gpu_options.SetGpuPriority(kLiteRtGpuPriorityLow);
97+
}
98+
9299
compilation_options.AddOpaqueOptions(std::move(gpu_options));
93100
}
94101

@@ -167,25 +174,29 @@ TfLiteStatus BenchmarkLiteRtModel::Init() {
167174
compiled_model_ =
168175
std::make_unique<litert::CompiledModel>(std::move(compiled_model_result));
169176

177+
LiteRtCompiledModelT* compiled_model_ptr = compiled_model_->Get();
178+
if (compiled_model_ptr == nullptr) {
179+
LITERT_LOG(LITERT_ERROR, "Compiled model is null");
180+
return kTfLiteError;
181+
}
182+
LITERT_ASSIGN_OR_RETURN(interpreter_, GetInterpreter(compiled_model_ptr),
183+
AsTfLiteStatus(_ << "Failed to get interpreter."));
184+
170185
if (!params_.Get<std::string>("model_runtime_info_output_file").empty()) {
171-
::tflite::Interpreter* interpreter_ptr = nullptr;
172-
LiteRtCompiledModelT* compiled_model_ptr = compiled_model_->Get();
173-
if (compiled_model_ptr == nullptr) {
174-
LITERT_LOG(LITERT_ERROR, "Compiled model is null");
175-
return kTfLiteError;
176-
}
177-
LITERT_ASSIGN_OR_RETURN(interpreter_ptr, GetInterpreter(compiled_model_ptr),
178-
AsTfLiteStatus(_ << "Failed to get interpreter."));
179186
model_runtime_info_listener_ =
180-
std::make_unique<ModelRuntimeInfoListener>(interpreter_ptr);
187+
std::make_unique<ModelRuntimeInfoListener>(interpreter_);
181188
AddListener(model_runtime_info_listener_.get());
182189
}
183190

184191
auto use_profiler = params_.Get<bool>("use_profiler");
185192
if (use_profiler) {
193+
run_summarizer_ = std::make_unique<tflite::profiling::ProfileSummarizer>();
186194
LITERT_ASSIGN_OR_ABORT(profiler_, compiled_model_->GetProfiler());
187195
profiler_.StartProfiling();
188196
}
197+
log_output_ =
198+
std::make_unique<BenchmarkLoggingListener>(run_summarizer_.get());
199+
AddListener(log_output_.get());
189200

190201
auto signature = params_.Get<std::string>("signature_to_run_for");
191202
if (signature.empty()) {

litert/tools/benchmark_litert_model.h

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ limitations under the License.
4040
#include "tflite/c/common.h"
4141
#include "tflite/interpreter.h"
4242
#include "tflite/profiling/model_runtime_info.h"
43+
#include "tflite/profiling/profile_buffer.h"
44+
#include "tflite/profiling/profile_summarizer.h"
4345
#include "tflite/tools/benchmark/benchmark_model.h"
4446
#include "tflite/tools/benchmark/benchmark_params.h"
4547
#include "tflite/tools/benchmark/proto/benchmark_result.pb.h"
@@ -54,8 +56,13 @@ using ::tflite::tools::benchmark::BenchmarkResult;
5456
class BenchmarkLoggingListener : public ::tflite::benchmark::BenchmarkListener {
5557
private:
5658
std::string result_file_path_ = "";
59+
tflite::profiling::ProfileSummarizer* run_summarizer_;
5760

5861
public:
62+
explicit BenchmarkLoggingListener(
63+
tflite::profiling::ProfileSummarizer* run_summarizer)
64+
: run_summarizer_(run_summarizer) {}
65+
5966
void OnBenchmarkStart(
6067
const ::tflite::benchmark::BenchmarkParams& params) override {
6168
if (!params.Get<std::string>("result_file_path").empty()) {
@@ -165,6 +172,11 @@ class BenchmarkLoggingListener : public ::tflite::benchmark::BenchmarkListener {
165172
result_file_path_.c_str());
166173
}
167174
}
175+
176+
if (run_summarizer_) {
177+
LITERT_LOG(LITERT_INFO, "\n%s",
178+
run_summarizer_->GetOutputString().c_str());
179+
}
168180
}
169181
};
170182

@@ -206,8 +218,7 @@ using ::tflite::utils::InputTensorData;
206218
class BenchmarkLiteRtModel : public BenchmarkModel {
207219
public:
208220
explicit BenchmarkLiteRtModel(BenchmarkParams params = DefaultParams())
209-
: BenchmarkModel(std::move(params)), log_output_() {
210-
AddListener(&log_output_);
221+
: BenchmarkModel(std::move(params)) {
211222
model_runtime_info_listener_ = nullptr;
212223
}
213224
~BenchmarkLiteRtModel() override = default;
@@ -336,10 +347,28 @@ class BenchmarkLiteRtModel : public BenchmarkModel {
336347
TfLiteStatus ResetInputsAndOutputs() override {
337348
if (profiler_) {
338349
profiler_.StopProfiling();
350+
339351
auto events = profiler_.GetEvents();
352+
std::vector<std::unique_ptr<tflite::profiling::ProfileEvent>>
353+
tflite_events;
354+
std::vector<const tflite::profiling::ProfileEvent*> tflite_ptr_events;
355+
tflite_events.reserve(events->size());
356+
tflite_ptr_events.reserve(events->size());
340357
for (const auto& event : *events) {
341-
LITERT_LOG(LITERT_INFO, "Event: %s", event.tag);
358+
auto tflite_event = std::make_unique<tflite::profiling::ProfileEvent>();
359+
// Refer to litert/litert/runtime/profiler.cc
360+
tflite_event->tag = event.tag;
361+
tflite_event->begin_timestamp_us = event.start_timestamp_us;
362+
tflite_event->elapsed_time = event.elapsed_time_us;
363+
tflite_event->event_type = event.event_type;
364+
tflite_event->event_metadata = event.event_metadata1;
365+
tflite_event->extra_event_metadata = event.event_metadata2;
366+
tflite_event->begin_mem_usage = event.begin_mem_usage;
367+
tflite_event->end_mem_usage = event.end_mem_usage;
368+
tflite_ptr_events.push_back(tflite_event.get());
369+
tflite_events.push_back(std::move(tflite_event));
342370
}
371+
run_summarizer_->ProcessProfiles(tflite_ptr_events, *interpreter_);
343372
profiler_.Reset();
344373
profiler_.StartProfiling();
345374
}
@@ -394,8 +423,12 @@ class BenchmarkLiteRtModel : public BenchmarkModel {
394423
std::unique_ptr<std::vector<litert::TensorBuffer>> input_buffers_;
395424
std::unique_ptr<std::vector<litert::TensorBuffer>> output_buffers_;
396425
litert::Profiler profiler_;
397-
BenchmarkLoggingListener log_output_;
426+
std::unique_ptr<BenchmarkLoggingListener> log_output_;
398427
std::unique_ptr<ModelRuntimeInfoListener> model_runtime_info_listener_;
428+
429+
// TFLite Interpreter is needed for run_summarizer_
430+
::tflite::Interpreter* interpreter_ = nullptr;
431+
std::unique_ptr<tflite::profiling::ProfileSummarizer> run_summarizer_;
399432
};
400433

401434
} // namespace benchmark

0 commit comments

Comments
 (0)