Skip to content

Commit 17b163e

Browse files
LukeBoyercopybara-github
authored andcommitted
Allow repeatable flags for test filtering in ats for convenience.
LiteRT-PiperOrigin-RevId: 826209225
1 parent 30faa6d commit 17b163e

File tree

17 files changed

+255
-103
lines changed

17 files changed

+255
-103
lines changed

litert/ats/BUILD

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ litert_test(
5151
"//litert/test/generators",
5252
"//litert/test/generators:common",
5353
"//tflite/schema:schema_fbs",
54+
"@com_google_absl//absl/flags:flag",
5455
"@com_google_absl//absl/flags:parse",
5556
"@com_google_absl//absl/log:absl_check",
57+
"@com_google_absl//absl/strings:string_view",
5658
"@com_google_googletest//:gtest",
5759
],
5860
)
@@ -62,17 +64,25 @@ cc_library(
6264
testonly = True,
6365
srcs = ["configure.cc"],
6466
hdrs = ["configure.h"],
65-
copts = commandline_flag_copts(),
67+
copts = commandline_flag_copts() + [
68+
"-DINCLUDE_QUALCOMM_COMPILE_FLAGS",
69+
"-DINCLUDE_QUALCOMM_RUNTIME_FLAGS",
70+
"-DINCLUDE_MEDIATEK_COMPILE_FLAGS",
71+
"-DINCLUDE_MEDIATEK_RUNTIME_FLAGS",
72+
],
6673
deps = [
6774
":common",
6875
"//litert/c:litert_common",
6976
"//litert/c/internal:litert_logging",
7077
"//litert/cc:litert_expected",
7178
"//litert/cc:litert_macros",
79+
"//litert/cc:litert_options",
7280
"//litert/cc/internal:litert_rng",
7381
"//litert/compiler/plugin:compiler_plugin",
7482
"//litert/core:filesystem_testonly",
7583
"//litert/core/model:model_serialize",
84+
"//litert/tools/flags/vendors:mediatek_flags",
85+
"//litert/tools/flags/vendors:qualcomm_flags",
7686
"@com_google_absl//absl/container:flat_hash_map",
7787
"@com_google_absl//absl/flags:flag",
7888
"@com_google_absl//absl/strings",
@@ -111,6 +121,7 @@ cc_library(
111121
"//litert/cc:litert_expected",
112122
"//litert/cc/internal:litert_detail",
113123
"//litert/cc/internal:litert_rng",
124+
"//litert/core:filesystem",
114125
"//litert/test/generators",
115126
],
116127
)
@@ -120,8 +131,10 @@ cc_test(
120131
srcs = ["executor_test.cc"],
121132
deps = [
122133
":executor",
134+
"//litert/c:litert_common",
123135
"//litert/c:litert_op_code",
124136
"//litert/cc:litert_buffer_ref",
137+
"//litert/cc:litert_options",
125138
"//litert/core/model",
126139
"//litert/test:matchers",
127140
"//litert/test:simple_buffer",
@@ -186,8 +199,10 @@ litert_device_test(
186199
defines = ["_TEST_NPU"],
187200
deps = [
188201
":executor",
202+
"//litert/c:litert_common",
189203
"//litert/c:litert_op_code",
190204
"//litert/cc:litert_buffer_ref",
205+
"//litert/cc:litert_options",
191206
"//litert/core/model",
192207
"//litert/test:matchers",
193208
"//litert/test:simple_buffer",
@@ -367,7 +382,6 @@ litert_define_ats(
367382
# "^(?:(?!npu_ats_87).)*$$",
368383
],
369384
extra_flags = [
370-
"--f16_range_for_f32=true",
371385
"--data_seed=42",
372386
],
373387
jit_suffix = "",

litert/ats/ats.cc

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
#include <cstddef>
1616
#include <cstdint>
1717
#include <iostream>
18+
#include <string>
19+
#include <vector>
1820

1921
#include <gtest/gtest.h>
22+
#include "absl/flags/flag.h" // from @com_google_absl
2023
#include "absl/flags/parse.h" // from @com_google_absl
2124
#include "absl/log/absl_check.h" // from @com_google_absl
25+
#include "absl/strings/string_view.h" // from @com_google_absl
2226
#include "litert/ats/compile_fixture.h"
2327
#include "litert/ats/configure.h"
2428
#include "litert/ats/inference_fixture.h"
@@ -119,7 +123,29 @@ int Ats() {
119123
} // namespace litert::testing
120124

121125
int main(int argc, char** argv) {
122-
::testing::InitGoogleTest(&argc, argv);
123-
absl::ParseCommandLine(argc, argv);
126+
// Shim to support repeatable flags which absl does not.
127+
std::vector<char*> absl_flags;
128+
static constexpr absl::string_view kDoRegisterPrefix = "--do_register=";
129+
static constexpr absl::string_view kDontRegisterPrefix = "--dont_register=";
130+
std::vector<std::string> do_register;
131+
std::vector<std::string> dont_register;
132+
for (int i = 0; i < argc; ++i) {
133+
if (::litert::StartsWith(argv[i], kDoRegisterPrefix)) {
134+
do_register.push_back(std::string(
135+
absl::string_view(argv[i]).substr(kDoRegisterPrefix.size())));
136+
} else if (::litert::StartsWith(argv[i], kDontRegisterPrefix)) {
137+
dont_register.push_back(std::string(
138+
absl::string_view(argv[i]).substr(kDontRegisterPrefix.size())));
139+
} else {
140+
absl_flags.push_back(argv[i]);
141+
}
142+
}
143+
144+
absl::SetFlag(&FLAGS_do_register, do_register);
145+
absl::SetFlag(&FLAGS_dont_register, dont_register);
146+
147+
int absl_argc = absl_flags.size();
148+
::testing::InitGoogleTest(&absl_argc, absl_flags.data());
149+
absl::ParseCommandLine(absl_argc, absl_flags.data());
124150
return litert::testing::Ats();
125151
}

litert/ats/check_ats.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ Expected<void> CheckAts() {
179179

180180
LITERT_ASSIGN_OR_RETURN(auto exec,
181181
NpuCompiledModelExecutor::Create(
182-
*model, npu_inference_options.DispatchDir()));
182+
*model, npu_inference_options.TargetOptions(),
183+
npu_inference_options.DispatchDir()));
183184
const auto& subgraph = *model->Subgraphs()[0];
184185
LITERT_ASSIGN_OR_RETURN(
185186
auto inputs, SimpleBuffer::LikeSignature(subgraph.Inputs().begin(),

litert/ats/common.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ struct TestNames {
4848
}
4949

5050
// Create with an explicit desc.
51-
static TestNames Create(size_t test_id, absl::string_view family,
52-
absl::string_view logic, absl::string_view test,
51+
static TestNames Create(size_t test_id, absl::string_view fixture,
52+
absl::string_view source, absl::string_view test,
53+
absl::string_view report_id,
5354
absl::string_view desc = "") {
54-
auto suite = MakeSuite(test_id, family, logic);
55-
return {suite, std::string(logic), std::string(desc), std::string(test)};
55+
auto suite = MakeSuite(test_id, fixture, source);
56+
return {suite, std::string(test), std::string(desc),
57+
std::string(report_id)};
5658
}
5759

5860
private:
@@ -98,7 +100,7 @@ enum class CompilationStatus {
98100
// Timing related types.
99101
using Clock = std::chrono::steady_clock;
100102
using TimePoint = Clock::time_point;
101-
using Nanoseconds = uint64_t;
103+
using Microseconds = uint64_t;
102104

103105
// Which backend to use as the "actual".
104106
enum class ExecutionBackend { kCpu, kGpu, kNpu };
@@ -175,7 +177,7 @@ void AbslStringify(Sink& sink, const CompilationStatus& status) {
175177
}
176178

177179
template <typename Sink>
178-
void AbslStringify(Sink& sink, const Nanoseconds& ns) {
180+
void AbslStringify(Sink& sink, const Microseconds& ns) {
179181
absl::Format(&sink, "%e", ns);
180182
}
181183

litert/ats/compile_capture.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,25 @@
2929
namespace litert::testing {
3030

3131
// Information about the time taken to compile the model.
32-
class CompilationTime : public Printable<Nanoseconds> {
32+
class CompilationTime : public Printable<Microseconds> {
3333
public:
3434
// Start timing.
3535
TimePoint Start() const { return Clock::now(); }
3636

3737
// Stop timing and record the latency.
3838
void Stop(const TimePoint& start) {
39-
std::chrono::duration<Nanoseconds, std::nano> nano = Clock::now() - start;
39+
std::chrono::duration<Microseconds, std::nano> nano = Clock::now() - start;
4040
nanos_ = nano.count();
4141
}
4242

43-
Nanoseconds Nanos() const { return nanos_; }
43+
Microseconds Nanos() const { return nanos_; }
4444

4545
CompilationTime() : Printable("CompilationTime", "compile_time(ns)") {}
4646

4747
private:
4848
Fields GetFields() const override { return Fields{nanos_}; }
4949

50-
Nanoseconds nanos_ = std::numeric_limits<Nanoseconds>::max();
50+
Microseconds nanos_ = std::numeric_limits<Microseconds>::max();
5151
};
5252

5353
// Type to hold all of the capturable information related compilation test

litert/ats/compile_capture_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TEST(AtsCompileCaptureTest, Basic) {
3333
e.model.name = "FOO";
3434

3535
ASSERT_NE(e.compilation_time.Nanos(),
36-
std::numeric_limits<Nanoseconds>::max());
36+
std::numeric_limits<Microseconds>::max());
3737

3838
std::ostringstream s;
3939
cap.Print(s);

litert/ats/configure.cc

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "litert/ats/configure.h"
1616

17+
#include <algorithm>
1718
#include <chrono> // NOLINT
1819
#include <cstddef>
1920
#include <cstdint>
@@ -33,7 +34,10 @@
3334
#include "litert/c/litert_common.h"
3435
#include "litert/cc/litert_expected.h"
3536
#include "litert/cc/litert_macros.h"
37+
#include "litert/cc/litert_options.h"
3638
#include "litert/compiler/plugin/compiler_plugin.h"
39+
#include "litert/tools/flags/vendors/mediatek_flags.h" // IWYU pragma: export
40+
#include "litert/tools/flags/vendors/qualcomm_flags.h" // IWYU pragma: export
3741

3842
ABSL_FLAG(std::optional<int>, data_seed, std::nullopt,
3943
"Seed for the buffer data generation.");
@@ -58,11 +62,11 @@ ABSL_FLAG(std::string, plugin_dir, "",
5862
"relevant for NPU.");
5963

6064
ABSL_FLAG(
61-
std::string, dont_register, "^$",
65+
std::vector<std::string>, dont_register, std::vector<std::string>{},
6266
"Regex for test selection. This is a negative search match, if the pattern "
6367
"can be found anywhere in the test name, it will be skipped.");
6468

65-
ABSL_FLAG(std::string, do_register, ".*",
69+
ABSL_FLAG(std::vector<std::string>, do_register, std::vector<std::string>{},
6670
"Regex for test selection. This is a positive search match, if the "
6771
"pattern can be found anywhere in the test name, it will be run. "
6872
"This has lower priority over the dont_register filter.");
@@ -113,6 +117,9 @@ namespace litert::testing {
113117

114118
namespace {
115119

120+
using ::litert::mediatek::MediatekOptionsFromFlags;
121+
using ::litert::qualcomm::QualcommOptionsFromFlags;
122+
116123
Expected<AtsConf::SeedMap> ParseParamSeedMap() {
117124
const auto seed_flags = absl::GetFlag(FLAGS_seeds);
118125
AtsConf::SeedMap seeds;
@@ -143,15 +150,34 @@ Expected<ExecutionBackend> ParseBackend() {
143150
}
144151
}
145152

153+
Expected<Options> ParseOptions(ExecutionBackend backend) {
154+
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
155+
if (backend == ExecutionBackend::kNpu) {
156+
if (auto qnn_opts = QualcommOptionsFromFlags()) {
157+
options.AddOpaqueOptions(std::move(*qnn_opts));
158+
}
159+
if (auto mediatek_opts = MediatekOptionsFromFlags()) {
160+
options.AddOpaqueOptions(std::move(*mediatek_opts));
161+
}
162+
options.SetHardwareAccelerators(kLiteRtHwAcceleratorNpu);
163+
} else if (backend == ExecutionBackend::kCpu) {
164+
options.SetHardwareAccelerators(kLiteRtHwAcceleratorCpu);
165+
} else if (backend == ExecutionBackend::kGpu) {
166+
options.SetHardwareAccelerators(kLiteRtHwAcceleratorGpu);
167+
}
168+
return options;
169+
}
170+
146171
Expected<std::optional<internal::CompilerPlugin>> ParsePlugin(
147172
absl::string_view plugin_dir, absl::string_view soc_manufacturer,
148-
bool compile_mode) {
173+
bool compile_mode, const Options& litert_options) {
149174
using R = std::optional<internal::CompilerPlugin>;
150175
if (!compile_mode) {
151176
return R(std::nullopt);
152177
}
153178
LITERT_ASSIGN_OR_RETURN(auto plugin, internal::CompilerPlugin::FindPlugin(
154-
soc_manufacturer, {plugin_dir}));
179+
soc_manufacturer, {plugin_dir},
180+
nullptr, litert_options.Get()));
155181
return R(std::move(plugin));
156182
}
157183

@@ -166,12 +192,15 @@ void Setup(const AtsConf& options) {
166192
Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
167193
LITERT_ASSIGN_OR_RETURN(auto seeds, ParseParamSeedMap());
168194
LITERT_ASSIGN_OR_RETURN(auto backend, ParseBackend());
169-
std::regex neg_re(absl::GetFlag(FLAGS_dont_register),
170-
std::regex_constants::ECMAScript);
171-
std::regex pos_re(absl::GetFlag(FLAGS_do_register),
172-
std::regex_constants::ECMAScript);
195+
std::vector<std::regex> neg_re;
196+
for (const auto& re : absl::GetFlag(FLAGS_dont_register)) {
197+
neg_re.push_back(std::regex(re, std::regex_constants::ECMAScript));
198+
}
199+
std::vector<std::regex> pos_re;
200+
for (const auto& re : absl::GetFlag(FLAGS_do_register)) {
201+
pos_re.push_back(std::regex(re, std::regex_constants::ECMAScript));
202+
}
173203
auto extra_models = absl::GetFlag(FLAGS_extra_models);
174-
auto f16_range_for_f32 = absl::GetFlag(FLAGS_f16_range_for_f32);
175204
auto data_seed = absl::GetFlag(FLAGS_data_seed);
176205
auto dispatch_dir = absl::GetFlag(FLAGS_dispatch_dir);
177206
auto plugin_dir = absl::GetFlag(FLAGS_plugin_dir);
@@ -190,15 +219,19 @@ Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
190219
auto limit = absl::GetFlag(FLAGS_limit);
191220
auto soc_manufacturer = absl::GetFlag(FLAGS_soc_manufacturer);
192221
auto soc_model = absl::GetFlag(FLAGS_soc_model);
222+
LITERT_ASSIGN_OR_RETURN(auto target_options, ParseOptions(backend));
223+
LITERT_ASSIGN_OR_RETURN(auto reference_options, Options::Create());
224+
reference_options.SetHardwareAccelerators(kLiteRtHwAcceleratorCpu);
193225
LITERT_ASSIGN_OR_RETURN(
194-
auto plugin, ParsePlugin(plugin_dir, soc_manufacturer, compile_mode));
226+
auto plugin,
227+
ParsePlugin(plugin_dir, soc_manufacturer, compile_mode, target_options));
195228
AtsConf res(std::move(seeds), backend, quiet, dispatch_dir, plugin_dir,
196229
std::move(neg_re), std::move(pos_re), std::move(extra_models),
197-
f16_range_for_f32, data_seed, iters_per_test,
198-
std::move(max_ms_per_test_opt), fail_on_timeout, dump_report,
199-
std::move(csv), compile_mode, std::move(models_out), limit,
200-
std::move(plugin), std::move(soc_manufacturer),
201-
std::move(soc_model));
230+
data_seed, iters_per_test, std::move(max_ms_per_test_opt),
231+
fail_on_timeout, dump_report, std::move(csv), compile_mode,
232+
std::move(models_out), limit, std::move(plugin),
233+
std::move(soc_manufacturer), std::move(soc_model),
234+
std::move(target_options), std::move(reference_options));
202235
Setup(res);
203236
return res;
204237
}
@@ -213,7 +246,15 @@ int AtsConf::GetSeedForParams(absl::string_view name) const {
213246
}
214247

215248
bool AtsConf::ShouldRegister(const std::string& name) const {
216-
return std::regex_search(name, pos_re_) && !std::regex_search(name, neg_re_);
249+
const bool include =
250+
pos_re_.empty() ||
251+
std::any_of(pos_re_.begin(), pos_re_.end(), [&name](const auto& re) {
252+
return std::regex_search(name, re);
253+
});
254+
const bool exclude = std::any_of(
255+
neg_re_.begin(), neg_re_.end(),
256+
[&name](const auto& re) { return std::regex_search(name, re); });
257+
return include && !exclude;
217258
};
218259

219260
bool AtsConf::ShouldRegister(absl::string_view name) const {

0 commit comments

Comments
 (0)