Skip to content

Commit 779733a

Browse files
committed
support loading from model bytes in common code
1 parent ef951d1 commit 779733a

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

mobile/examples/model_tester/common/include/model_runner.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
#pragma once
22

3+
#include <cstddef>
34
#include <cstdint>
45

56
#include <chrono>
67
#include <optional>
8+
#include <span>
79
#include <string>
810
#include <unordered_map>
11+
#include <variant>
912
#include <vector>
1013

1114
namespace model_runner {
@@ -14,8 +17,10 @@ using Clock = std::chrono::steady_clock;
1417
using Duration = Clock::duration;
1518

1619
struct RunConfig {
17-
// Path to the model to run.
18-
std::string model_path{};
20+
using ModelPathOrBytes = std::variant<std::string, std::span<const std::byte>>;
21+
22+
// Path or bytes of the model to run.
23+
ModelPathOrBytes model_path_or_bytes{};
1924

2025
// Whether to run a warmup iteration before running the measured (timed) iterations.
2126
bool run_warmup_iteration{true};

mobile/examples/model_tester/common/model_runner.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,19 @@ RunResult Run(const RunConfig& run_config) {
182182
}
183183

184184
Timer timer{};
185-
auto session = Ort::Session{env, run_config.model_path.c_str(), session_options};
186-
run_result.load_duration = timer.Elapsed();
185+
186+
auto session = Ort::Session{nullptr};
187+
if (std::holds_alternative<std::string>(run_config.model_path_or_bytes)) {
188+
const auto& model_path = std::get<std::string>(run_config.model_path_or_bytes);
189+
timer.Reset();
190+
session = Ort::Session{env, model_path.c_str(), session_options};
191+
run_result.load_duration = timer.Elapsed();
192+
} else {
193+
const auto& model_bytes = std::get<std::span<const std::byte>>(run_config.model_path_or_bytes);
194+
timer.Reset();
195+
session = Ort::Session{env, model_bytes.data(), model_bytes.size(), session_options};
196+
run_result.load_duration = timer.Elapsed();
197+
}
187198

188199
auto input_names = GetModelInputNames(session);
189200
auto input_name_cstrs = GetCstrs(input_names);

mobile/examples/model_tester/ios/ModelRunner/model_runner_objc_wrapper.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ @implementation ModelRunnerRunConfig {
1515
}
1616

1717
- (void)setModelPath:(nonnull NSString*)modelPath {
18-
_runConfig.model_path = modelPath.UTF8String;
18+
_runConfig.model_path_or_bytes = std::string{modelPath.UTF8String};
1919
}
2020

2121
- (void)setNumIterations:(NSUInteger)numIterations {

0 commit comments

Comments
 (0)