Skip to content

Commit 9158713

Browse files
LukeBoyercopybara-github
authored andcommitted
Allow repeatable flags for test filtering in ats for convenience.
LiteRT-PiperOrigin-RevId: 826209225
1 parent 30faa6d commit 9158713

File tree

17 files changed

+254
-103
lines changed

17 files changed

+254
-103
lines changed

litert/ats/BUILD

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ litert_test(
5151
"//litert/test/generators",
5252
"//litert/test/generators:common",
5353
"//tflite/schema:schema_fbs",
54+
"@com_google_absl//absl/flags:flag",
5455
"@com_google_absl//absl/flags:parse",
5556
"@com_google_absl//absl/log:absl_check",
57+
"@com_google_absl//absl/strings:string_view",
5658
"@com_google_googletest//:gtest",
5759
],
5860
)
@@ -62,17 +64,25 @@ cc_library(
6264
testonly = True,
6365
srcs = ["configure.cc"],
6466
hdrs = ["configure.h"],
65-
copts = commandline_flag_copts(),
67+
copts = commandline_flag_copts() + [
68+
"-DINCLUDE_QUALCOMM_COMPILE_FLAGS",
69+
"-DINCLUDE_QUALCOMM_RUNTIME_FLAGS",
70+
"-DINCLUDE_MEDIATEK_COMPILE_FLAGS",
71+
"-DINCLUDE_MEDIATEK_RUNTIME_FLAGS",
72+
],
6673
deps = [
6774
":common",
6875
"//litert/c:litert_common",
6976
"//litert/c/internal:litert_logging",
7077
"//litert/cc:litert_expected",
7178
"//litert/cc:litert_macros",
79+
"//litert/cc:litert_options",
7280
"//litert/cc/internal:litert_rng",
7381
"//litert/compiler/plugin:compiler_plugin",
7482
"//litert/core:filesystem_testonly",
7583
"//litert/core/model:model_serialize",
84+
"//litert/tools/flags/vendors:mediatek_flags",
85+
"//litert/tools/flags/vendors:qualcomm_flags",
7686
"@com_google_absl//absl/container:flat_hash_map",
7787
"@com_google_absl//absl/flags:flag",
7888
"@com_google_absl//absl/strings",
@@ -111,6 +121,7 @@ cc_library(
111121
"//litert/cc:litert_expected",
112122
"//litert/cc/internal:litert_detail",
113123
"//litert/cc/internal:litert_rng",
124+
"//litert/core:filesystem",
114125
"//litert/test/generators",
115126
],
116127
)
@@ -120,8 +131,10 @@ cc_test(
120131
srcs = ["executor_test.cc"],
121132
deps = [
122133
":executor",
134+
"//litert/c:litert_common",
123135
"//litert/c:litert_op_code",
124136
"//litert/cc:litert_buffer_ref",
137+
"//litert/cc:litert_options",
125138
"//litert/core/model",
126139
"//litert/test:matchers",
127140
"//litert/test:simple_buffer",
@@ -186,8 +199,10 @@ litert_device_test(
186199
defines = ["_TEST_NPU"],
187200
deps = [
188201
":executor",
202+
"//litert/c:litert_common",
189203
"//litert/c:litert_op_code",
190204
"//litert/cc:litert_buffer_ref",
205+
"//litert/cc:litert_options",
191206
"//litert/core/model",
192207
"//litert/test:matchers",
193208
"//litert/test:simple_buffer",
@@ -367,7 +382,6 @@ litert_define_ats(
367382
# "^(?:(?!npu_ats_87).)*$$",
368383
],
369384
extra_flags = [
370-
"--f16_range_for_f32=true",
371385
"--data_seed=42",
372386
],
373387
jit_suffix = "",

litert/ats/ats.cc

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
#include <cstddef>
1616
#include <cstdint>
1717
#include <iostream>
18+
#include <string>
19+
#include <vector>
1820

1921
#include <gtest/gtest.h>
22+
#include "absl/flags/flag.h" // from @com_google_absl
2023
#include "absl/flags/parse.h" // from @com_google_absl
2124
#include "absl/log/absl_check.h" // from @com_google_absl
25+
#include "absl/strings/string_view.h" // from @com_google_absl
2226
#include "litert/ats/compile_fixture.h"
2327
#include "litert/ats/configure.h"
2428
#include "litert/ats/inference_fixture.h"
@@ -119,7 +123,29 @@ int Ats() {
119123
} // namespace litert::testing
120124

121125
int main(int argc, char** argv) {
122-
::testing::InitGoogleTest(&argc, argv);
123-
absl::ParseCommandLine(argc, argv);
126+
// Shim to support repeatable flags which absl does not.
127+
std::vector<char*> absl_flags;
128+
static constexpr absl::string_view kDoRegisterPrefix = "--do_register=";
129+
static constexpr absl::string_view kDontRegisterPrefix = "--dont_register=";
130+
std::vector<std::string> do_register;
131+
std::vector<std::string> dont_register;
132+
for (int i = 0; i < argc; ++i) {
133+
if (::litert::StartsWith(argv[i], kDoRegisterPrefix)) {
134+
do_register.push_back(std::string(
135+
absl::string_view(argv[i]).substr(kDoRegisterPrefix.size())));
136+
} else if (::litert::StartsWith(argv[i], kDontRegisterPrefix)) {
137+
dont_register.push_back(std::string(
138+
absl::string_view(argv[i]).substr(kDontRegisterPrefix.size())));
139+
} else {
140+
absl_flags.push_back(argv[i]);
141+
}
142+
}
143+
144+
absl::SetFlag(&FLAGS_do_register, do_register);
145+
absl::SetFlag(&FLAGS_dont_register, dont_register);
146+
147+
int absl_argc = absl_flags.size();
148+
::testing::InitGoogleTest(&absl_argc, absl_flags.data());
149+
absl::ParseCommandLine(absl_argc, absl_flags.data());
124150
return litert::testing::Ats();
125151
}

litert/ats/check_ats.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ Expected<void> CheckAts() {
179179

180180
LITERT_ASSIGN_OR_RETURN(auto exec,
181181
NpuCompiledModelExecutor::Create(
182-
*model, npu_inference_options.DispatchDir()));
182+
*model, npu_inference_options.TargetOptions(),
183+
npu_inference_options.DispatchDir()));
183184
const auto& subgraph = *model->Subgraphs()[0];
184185
LITERT_ASSIGN_OR_RETURN(
185186
auto inputs, SimpleBuffer::LikeSignature(subgraph.Inputs().begin(),

litert/ats/common.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ struct TestNames {
4848
}
4949

5050
// Create with an explicit desc.
51-
static TestNames Create(size_t test_id, absl::string_view family,
52-
absl::string_view logic, absl::string_view test,
51+
static TestNames Create(size_t test_id, absl::string_view fixture,
52+
absl::string_view source, absl::string_view test,
53+
absl::string_view report_id,
5354
absl::string_view desc = "") {
54-
auto suite = MakeSuite(test_id, family, logic);
55-
return {suite, std::string(logic), std::string(desc), std::string(test)};
55+
auto suite = MakeSuite(test_id, fixture, source);
56+
return {suite, std::string(test), std::string(desc),
57+
std::string(report_id)};
5658
}
5759

5860
private:
@@ -98,7 +100,7 @@ enum class CompilationStatus {
98100
// Timing related types.
99101
using Clock = std::chrono::steady_clock;
100102
using TimePoint = Clock::time_point;
101-
using Nanoseconds = uint64_t;
103+
using Microseconds = uint64_t;
102104

103105
// Which backend to use as the "actual".
104106
enum class ExecutionBackend { kCpu, kGpu, kNpu };
@@ -175,7 +177,7 @@ void AbslStringify(Sink& sink, const CompilationStatus& status) {
175177
}
176178

177179
template <typename Sink>
178-
void AbslStringify(Sink& sink, const Nanoseconds& ns) {
180+
void AbslStringify(Sink& sink, const Microseconds& ns) {
179181
absl::Format(&sink, "%e", ns);
180182
}
181183

litert/ats/compile_capture.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,25 @@
2929
namespace litert::testing {
3030

3131
// Information about the time taken to compile the model.
32-
class CompilationTime : public Printable<Nanoseconds> {
32+
class CompilationTime : public Printable<Microseconds> {
3333
public:
3434
// Start timing.
3535
TimePoint Start() const { return Clock::now(); }
3636

3737
// Stop timing and record the latency.
3838
void Stop(const TimePoint& start) {
39-
std::chrono::duration<Nanoseconds, std::nano> nano = Clock::now() - start;
39+
std::chrono::duration<Microseconds, std::nano> nano = Clock::now() - start;
4040
nanos_ = nano.count();
4141
}
4242

43-
Nanoseconds Nanos() const { return nanos_; }
43+
Microseconds Nanos() const { return nanos_; }
4444

4545
CompilationTime() : Printable("CompilationTime", "compile_time(ns)") {}
4646

4747
private:
4848
Fields GetFields() const override { return Fields{nanos_}; }
4949

50-
Nanoseconds nanos_ = std::numeric_limits<Nanoseconds>::max();
50+
Microseconds nanos_ = std::numeric_limits<Microseconds>::max();
5151
};
5252

5353
// Type to hold all of the capturable information related compilation test

litert/ats/compile_capture_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TEST(AtsCompileCaptureTest, Basic) {
3333
e.model.name = "FOO";
3434

3535
ASSERT_NE(e.compilation_time.Nanos(),
36-
std::numeric_limits<Nanoseconds>::max());
36+
std::numeric_limits<Microseconds>::max());
3737

3838
std::ostringstream s;
3939
cap.Print(s);

litert/ats/configure.cc

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@
3333
#include "litert/c/litert_common.h"
3434
#include "litert/cc/litert_expected.h"
3535
#include "litert/cc/litert_macros.h"
36+
#include "litert/cc/litert_options.h"
3637
#include "litert/compiler/plugin/compiler_plugin.h"
38+
#include "litert/tools/flags/vendors/mediatek_flags.h" // IWYU pragma: export
39+
#include "litert/tools/flags/vendors/qualcomm_flags.h" // IWYU pragma: export
3740

3841
ABSL_FLAG(std::optional<int>, data_seed, std::nullopt,
3942
"Seed for the buffer data generation.");
@@ -58,11 +61,11 @@ ABSL_FLAG(std::string, plugin_dir, "",
5861
"relevant for NPU.");
5962

6063
ABSL_FLAG(
61-
std::string, dont_register, "^$",
64+
std::vector<std::string>, dont_register, std::vector<std::string>{},
6265
"Regex for test selection. This is a negative search match, if the pattern "
6366
"can be found anywhere in the test name, it will be skipped.");
6467

65-
ABSL_FLAG(std::string, do_register, ".*",
68+
ABSL_FLAG(std::vector<std::string>, do_register, std::vector<std::string>{},
6669
"Regex for test selection. This is a positive search match, if the "
6770
"pattern can be found anywhere in the test name, it will be run. "
6871
"This has lower priority over the dont_register filter.");
@@ -113,6 +116,9 @@ namespace litert::testing {
113116

114117
namespace {
115118

119+
using ::litert::mediatek::MediatekOptionsFromFlags;
120+
using ::litert::qualcomm::QualcommOptionsFromFlags;
121+
116122
Expected<AtsConf::SeedMap> ParseParamSeedMap() {
117123
const auto seed_flags = absl::GetFlag(FLAGS_seeds);
118124
AtsConf::SeedMap seeds;
@@ -143,15 +149,34 @@ Expected<ExecutionBackend> ParseBackend() {
143149
}
144150
}
145151

152+
Expected<Options> ParseOptions(ExecutionBackend backend) {
153+
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
154+
if (backend == ExecutionBackend::kNpu) {
155+
if (auto qnn_opts = QualcommOptionsFromFlags()) {
156+
options.AddOpaqueOptions(std::move(*qnn_opts));
157+
}
158+
if (auto mediatek_opts = MediatekOptionsFromFlags()) {
159+
options.AddOpaqueOptions(std::move(*mediatek_opts));
160+
}
161+
options.SetHardwareAccelerators(kLiteRtHwAcceleratorNpu);
162+
} else if (backend == ExecutionBackend::kCpu) {
163+
options.SetHardwareAccelerators(kLiteRtHwAcceleratorCpu);
164+
} else if (backend == ExecutionBackend::kGpu) {
165+
options.SetHardwareAccelerators(kLiteRtHwAcceleratorGpu);
166+
}
167+
return options;
168+
}
169+
146170
Expected<std::optional<internal::CompilerPlugin>> ParsePlugin(
147171
absl::string_view plugin_dir, absl::string_view soc_manufacturer,
148-
bool compile_mode) {
172+
bool compile_mode, const Options& litert_options) {
149173
using R = std::optional<internal::CompilerPlugin>;
150174
if (!compile_mode) {
151175
return R(std::nullopt);
152176
}
153177
LITERT_ASSIGN_OR_RETURN(auto plugin, internal::CompilerPlugin::FindPlugin(
154-
soc_manufacturer, {plugin_dir}));
178+
soc_manufacturer, {plugin_dir},
179+
nullptr, litert_options.Get()));
155180
return R(std::move(plugin));
156181
}
157182

@@ -166,12 +191,15 @@ void Setup(const AtsConf& options) {
166191
Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
167192
LITERT_ASSIGN_OR_RETURN(auto seeds, ParseParamSeedMap());
168193
LITERT_ASSIGN_OR_RETURN(auto backend, ParseBackend());
169-
std::regex neg_re(absl::GetFlag(FLAGS_dont_register),
170-
std::regex_constants::ECMAScript);
171-
std::regex pos_re(absl::GetFlag(FLAGS_do_register),
172-
std::regex_constants::ECMAScript);
194+
std::vector<std::regex> neg_re;
195+
for (const auto& re : absl::GetFlag(FLAGS_dont_register)) {
196+
neg_re.push_back(std::regex(re, std::regex_constants::ECMAScript));
197+
}
198+
std::vector<std::regex> pos_re;
199+
for (const auto& re : absl::GetFlag(FLAGS_do_register)) {
200+
pos_re.push_back(std::regex(re, std::regex_constants::ECMAScript));
201+
}
173202
auto extra_models = absl::GetFlag(FLAGS_extra_models);
174-
auto f16_range_for_f32 = absl::GetFlag(FLAGS_f16_range_for_f32);
175203
auto data_seed = absl::GetFlag(FLAGS_data_seed);
176204
auto dispatch_dir = absl::GetFlag(FLAGS_dispatch_dir);
177205
auto plugin_dir = absl::GetFlag(FLAGS_plugin_dir);
@@ -190,15 +218,19 @@ Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
190218
auto limit = absl::GetFlag(FLAGS_limit);
191219
auto soc_manufacturer = absl::GetFlag(FLAGS_soc_manufacturer);
192220
auto soc_model = absl::GetFlag(FLAGS_soc_model);
221+
LITERT_ASSIGN_OR_RETURN(auto target_options, ParseOptions(backend));
222+
LITERT_ASSIGN_OR_RETURN(auto reference_options, Options::Create());
223+
reference_options.SetHardwareAccelerators(kLiteRtHwAcceleratorCpu);
193224
LITERT_ASSIGN_OR_RETURN(
194-
auto plugin, ParsePlugin(plugin_dir, soc_manufacturer, compile_mode));
225+
auto plugin,
226+
ParsePlugin(plugin_dir, soc_manufacturer, compile_mode, target_options));
195227
AtsConf res(std::move(seeds), backend, quiet, dispatch_dir, plugin_dir,
196228
std::move(neg_re), std::move(pos_re), std::move(extra_models),
197-
f16_range_for_f32, data_seed, iters_per_test,
198-
std::move(max_ms_per_test_opt), fail_on_timeout, dump_report,
199-
std::move(csv), compile_mode, std::move(models_out), limit,
200-
std::move(plugin), std::move(soc_manufacturer),
201-
std::move(soc_model));
229+
data_seed, iters_per_test, std::move(max_ms_per_test_opt),
230+
fail_on_timeout, dump_report, std::move(csv), compile_mode,
231+
std::move(models_out), limit, std::move(plugin),
232+
std::move(soc_manufacturer), std::move(soc_model),
233+
std::move(target_options), std::move(reference_options));
202234
Setup(res);
203235
return res;
204236
}
@@ -213,7 +245,15 @@ int AtsConf::GetSeedForParams(absl::string_view name) const {
213245
}
214246

215247
bool AtsConf::ShouldRegister(const std::string& name) const {
216-
return std::regex_search(name, pos_re_) && !std::regex_search(name, neg_re_);
248+
const bool include =
249+
pos_re_.empty() ||
250+
std::any_of(pos_re_.begin(), pos_re_.end(), [&name](const auto& re) {
251+
return std::regex_search(name, re);
252+
});
253+
const bool exclude = std::any_of(
254+
neg_re_.begin(), neg_re_.end(),
255+
[&name](const auto& re) { return std::regex_search(name, re); });
256+
return include && !exclude;
217257
};
218258

219259
bool AtsConf::ShouldRegister(absl::string_view name) const {

0 commit comments

Comments
 (0)