
Commit f265cf3

add model_tester files
1 parent afba506 commit f265cf3

17 files changed: +1197 -0 lines changed

model_runner.h
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
#pragma once

#include <cstddef>

#include <chrono>
#include <optional>
#include <string>
#include <vector>

namespace model_runner {

// Configuration for a benchmarking run of a single model.
struct RunConfig {
  std::string model_path{};

  size_t num_warmup_iterations{};
  size_t num_iterations{};

  // Optional logging level, cast to OrtLoggingLevel by the implementation.
  std::optional<int> log_level{};
};

using Clock = std::chrono::steady_clock;
using Duration = Clock::duration;

// Timing results: model load time and the duration of each measured run.
struct RunResult {
  Duration load_duration;
  std::vector<Duration> run_durations;
};

// Loads the model, performs the configured warmup and measured runs, and returns the timings.
RunResult Run(const RunConfig& run_config);

// Formats a human-readable summary (load time and latency statistics) of a completed run.
std::string GetRunSummary(const RunConfig& run_config, const RunResult& run_result);

}  // namespace model_runner
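
For context, a minimal usage sketch of this API (not part of the commit): the model path and iteration counts below are illustrative placeholders, and the program assumes it is linked against the model_runner sources and ONNX Runtime.

#include <iostream>

#include "model_runner.h"

int main() {
  model_runner::RunConfig config{};
  config.model_path = "/path/to/model.onnx";  // hypothetical path
  config.num_warmup_iterations = 1;
  config.num_iterations = 100;

  // Load the model, run it, and print the latency summary.
  const auto result = model_runner::Run(config);
  std::cout << model_runner::GetRunSummary(config, result);
  return 0;
}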
Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
#include "model_runner.h"

#include <cstddef>

#include <algorithm>
#include <chrono>
#include <filesystem>
#include <format>
#include <iterator>
#include <numeric>
#include <span>
#include <stdexcept>

#include "onnxruntime_cxx_api.h"

namespace model_runner {

namespace {

size_t GetDataTypeSizeInBytes(ONNXTensorElementDataType data_type) {
  switch (data_type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      return 1;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      return 2;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      return 4;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      return 8;
    default:
      throw std::invalid_argument(std::format("unsupported tensor data type: {}", static_cast<int>(data_type)));
  }
}

void FillTensorWithZeroes(Ort::Value& value) {
  const auto tensor_info = value.GetTensorTypeAndShapeInfo();
  const auto data_type = tensor_info.GetElementType();
  const auto num_elements = tensor_info.GetElementCount();
  const auto data_type_size_in_bytes = GetDataTypeSizeInBytes(data_type);
  const auto data_size_in_bytes = num_elements * data_type_size_in_bytes;

  std::byte* data = static_cast<std::byte*>(value.GetTensorMutableRawData());
  std::fill(data, data + data_size_in_bytes, std::byte{0});
}

// Creates a zero-filled input tensor for each of the model's inputs.
std::vector<Ort::Value> GetModelInputValues(const Ort::Session& session) {
  const auto num_inputs = session.GetInputCount();

  std::vector<Ort::Value> input_values{};
  input_values.reserve(num_inputs);

  Ort::AllocatorWithDefaultOptions allocator{};

  for (size_t i = 0; i < num_inputs; ++i) {
    auto type_info = session.GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();

    auto tensor_shape = tensor_info.GetShape();
    // Make this a static shape: replace any dynamic dimension (-1) with 1.
    for (auto& dim : tensor_shape) {
      if (dim == -1) {
        dim = 1;
      }
    }

    const auto tensor_data_type = tensor_info.GetElementType();

    auto value = Ort::Value::CreateTensor(allocator, tensor_shape.data(), tensor_shape.size(), tensor_data_type);

    FillTensorWithZeroes(value);

    input_values.emplace_back(std::move(value));
  }

  return input_values;
}

std::vector<std::string> GetModelInputOrOutputNames(const Ort::Session& session, bool is_input) {
  const auto num_inputs_or_outputs = is_input ? session.GetInputCount() : session.GetOutputCount();

  std::vector<std::string> names{};
  names.reserve(num_inputs_or_outputs);

  auto allocator = Ort::AllocatorWithDefaultOptions{};
  for (size_t i = 0; i < num_inputs_or_outputs; ++i) {
    auto name = is_input ? session.GetInputNameAllocated(i, allocator)
                         : session.GetOutputNameAllocated(i, allocator);
    names.emplace_back(name.get());
  }

  return names;
}

std::vector<std::string> GetModelInputNames(const Ort::Session& session) {
  return GetModelInputOrOutputNames(session, /* is_input */ true);
}

std::vector<std::string> GetModelOutputNames(const Ort::Session& session) {
  return GetModelInputOrOutputNames(session, /* is_input */ false);
}

std::vector<const char*> GetCstrs(std::span<const std::string> strs) {
  std::vector<const char*> cstrs{};
  cstrs.reserve(strs.size());
  std::transform(strs.begin(), strs.end(), std::back_inserter(cstrs),
                 [](const std::string& str) { return str.c_str(); });
  return cstrs;
}

class Timer {
 public:
  Timer() { Reset(); }

  void Reset() { start_ = Clock::now(); }

  Duration Elapsed() const { return Clock::now() - start_; }

 private:
  Clock::time_point start_;
};

struct RunResultStats {
  using DurationFp = std::chrono::duration<float, Duration::period>;

  size_t n;
  DurationFp average;
  Duration min, max;
  Duration p50, p90, p99;
};

RunResultStats ComputeRunResultStats(const RunResult& run_result) {
  using DurationFp = RunResultStats::DurationFp;

  const auto& run_durations = run_result.run_durations;

  RunResultStats stats{};
  const auto n = run_durations.size();
  stats.n = n;
  if (n > 0) {
    const auto total_run_duration = std::accumulate(run_durations.begin(), run_durations.end(),
                                                    DurationFp{0.0f});
    stats.average = DurationFp{total_run_duration.count() / n};

    auto sorted_run_durations = run_durations;
    std::sort(sorted_run_durations.begin(), sorted_run_durations.end());
    stats.min = sorted_run_durations.front();
    stats.max = sorted_run_durations.back();
    stats.p50 = sorted_run_durations[static_cast<size_t>(0.5f * n)];
    stats.p90 = sorted_run_durations[static_cast<size_t>(0.9f * n)];
    stats.p99 = sorted_run_durations[static_cast<size_t>(0.99f * n)];
  }

  return stats;
}

}  // namespace

RunResult Run(const RunConfig& run_config) {
  RunResult run_result{};

  auto env = Ort::Env{};

  if (run_config.log_level.has_value()) {
    env.UpdateEnvWithCustomLogLevel(static_cast<OrtLoggingLevel>(*run_config.log_level));
  }

  auto session_options = Ort::SessionOptions{};

  // Session creation (model load) is timed as load_duration.
  Timer timer{};
  auto session = Ort::Session{env, run_config.model_path.c_str(), session_options};
  run_result.load_duration = timer.Elapsed();

  auto input_names = GetModelInputNames(session);
  auto input_name_cstrs = GetCstrs(input_names);

  auto input_values = GetModelInputValues(session);

  auto output_names = GetModelOutputNames(session);
  auto output_name_cstrs = GetCstrs(output_names);

  auto run_options = Ort::RunOptions{};

  // Warmup runs (not measured).
  for (size_t i = 0; i < run_config.num_warmup_iterations; ++i) {
    auto outputs = session.Run(run_options,
                               input_name_cstrs.data(), input_values.data(), input_values.size(),
                               output_name_cstrs.data(), output_name_cstrs.size());
  }

  // Measured runs.
  run_result.run_durations.reserve(run_config.num_iterations);
  for (size_t i = 0; i < run_config.num_iterations; ++i) {
    timer.Reset();
    auto outputs = session.Run(run_options,
                               input_name_cstrs.data(), input_values.data(), input_values.size(),
                               output_name_cstrs.data(), output_name_cstrs.size());
    run_result.run_durations.push_back(timer.Elapsed());
  }

  return run_result;
}

std::string GetRunSummary(const RunConfig& run_config, const RunResult& run_result) {
  auto to_display_duration = []<typename Rep, typename Period>(std::chrono::duration<Rep, Period> d) {
    using DisplayPeriod = std::chrono::microseconds::period;
    using DisplayDuration = std::chrono::duration<Rep, DisplayPeriod>;
    return std::chrono::duration_cast<DisplayDuration>(d);
  };

  const auto model_path = std::filesystem::path{run_config.model_path};

  const auto stats = ComputeRunResultStats(run_result);

  const auto summary = std::format(
      "Model: {}\n"
      "Load time: {}\n"
      "N (number of runs): {}\n"
      "Latency\n"
      "  avg: {}\n"
      "  p50: {}\n"
      "  p90: {}\n"
      "  p99: {}\n"
      "  min: {}\n"
      "  max: {}\n",
      model_path.filename().string(),
      to_display_duration(run_result.load_duration),
      stats.n,
      to_display_duration(stats.average),
      to_display_duration(stats.p50),
      to_display_duration(stats.p90),
      to_display_duration(stats.p99),
      to_display_duration(stats.min),
      to_display_duration(stats.max));

  return summary;
}

}  // namespace model_runner
model_runner_objc_wrapper.h
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
#import <Foundation/Foundation.h>
#include <stdint.h>

NS_ASSUME_NONNULL_BEGIN

/**
 * This class is an Objective-C wrapper around the C++ model runner functionality.
 */
@interface ModelRunner : NSObject

+ (nullable NSString*)runWithModelPath:(NSString*)modelPath
                         numIterations:(uint32_t)numIterations
                                 error:(NSError**)error;

@end

NS_ASSUME_NONNULL_END
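
For reference, a minimal sketch of how a client might call this wrapper from Objective-C (not part of the commit); the model path and iteration count are illustrative placeholders.

#import "model_runner_objc_wrapper.h"

static void RunModelExample(void) {
  NSError* error = nil;
  // Run the (hypothetical) model 100 times and print the latency summary.
  NSString* summary = [ModelRunner runWithModelPath:@"/path/to/model.onnx"
                                      numIterations:100
                                              error:&error];
  if (summary != nil) {
    NSLog(@"%@", summary);
  } else {
    NSLog(@"Model run failed: %@", error.localizedDescription);
  }
}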
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
#import "model_runner_objc_wrapper.h"

#include <exception>

#include "model_runner.h"

@implementation ModelRunner

+ (nullable NSString*)runWithModelPath:(NSString*)modelPath
                         numIterations:(uint32_t)numIterations
                                 error:(NSError**)error {
  try {
    model_runner::RunConfig config{};
    config.model_path = modelPath.UTF8String;
    config.num_iterations = numIterations;
    // One warmup run before the measured iterations.
    config.num_warmup_iterations = 1;

    auto result = model_runner::Run(config);

    auto summary = model_runner::GetRunSummary(config, result);

    return [NSString stringWithUTF8String:summary.c_str()];
  } catch (const std::exception& e) {
    if (error) {
      NSString* description = [NSString stringWithCString:e.what()
                                                 encoding:NSUTF8StringEncoding];

      *error = [NSError errorWithDomain:@"ModelRunner"
                                   code:0
                               userInfo:@{NSLocalizedDescriptionKey : description}];
    }
    return nil;
  }
}

@end
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
//
// Use this file to import your target's public headers that you would like to expose to Swift.
//

#import "model_runner_objc_wrapper.h"
