Skip to content

Commit 28b4dc0

Browse files
LukeBoyercopybara-github
authored andcommitted
Allow repeatable flags for test filtering in ats for convenience.
LiteRT-PiperOrigin-RevId: 826209225
1 parent 30faa6d commit 28b4dc0

File tree

17 files changed

+250
-103
lines changed

17 files changed

+250
-103
lines changed

litert/ats/BUILD

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ litert_test(
5151
"//litert/test/generators",
5252
"//litert/test/generators:common",
5353
"//tflite/schema:schema_fbs",
54+
"@com_google_absl//absl/flags:flag",
5455
"@com_google_absl//absl/flags:parse",
5556
"@com_google_absl//absl/log:absl_check",
5657
"@com_google_googletest//:gtest",
@@ -62,17 +63,25 @@ cc_library(
6263
testonly = True,
6364
srcs = ["configure.cc"],
6465
hdrs = ["configure.h"],
65-
copts = commandline_flag_copts(),
66+
copts = commandline_flag_copts() + [
67+
"-DINCLUDE_QUALCOMM_COMPILE_FLAGS",
68+
"-DINCLUDE_QUALCOMM_RUNTIME_FLAGS",
69+
"-DINCLUDE_MEDIATEK_COMPILE_FLAGS",
70+
"-DINCLUDE_MEDIATEK_RUNTIME_FLAGS",
71+
],
6672
deps = [
6773
":common",
6874
"//litert/c:litert_common",
6975
"//litert/c/internal:litert_logging",
7076
"//litert/cc:litert_expected",
7177
"//litert/cc:litert_macros",
78+
"//litert/cc:litert_options",
7279
"//litert/cc/internal:litert_rng",
7380
"//litert/compiler/plugin:compiler_plugin",
7481
"//litert/core:filesystem_testonly",
7582
"//litert/core/model:model_serialize",
83+
"//litert/tools/flags/vendors:mediatek_flags",
84+
"//litert/tools/flags/vendors:qualcomm_flags",
7685
"@com_google_absl//absl/container:flat_hash_map",
7786
"@com_google_absl//absl/flags:flag",
7887
"@com_google_absl//absl/strings",
@@ -111,6 +120,7 @@ cc_library(
111120
"//litert/cc:litert_expected",
112121
"//litert/cc/internal:litert_detail",
113122
"//litert/cc/internal:litert_rng",
123+
"//litert/core:filesystem",
114124
"//litert/test/generators",
115125
],
116126
)
@@ -120,8 +130,10 @@ cc_test(
120130
srcs = ["executor_test.cc"],
121131
deps = [
122132
":executor",
133+
"//litert/c:litert_common",
123134
"//litert/c:litert_op_code",
124135
"//litert/cc:litert_buffer_ref",
136+
"//litert/cc:litert_options",
125137
"//litert/core/model",
126138
"//litert/test:matchers",
127139
"//litert/test:simple_buffer",
@@ -186,8 +198,10 @@ litert_device_test(
186198
defines = ["_TEST_NPU"],
187199
deps = [
188200
":executor",
201+
"//litert/c:litert_common",
189202
"//litert/c:litert_op_code",
190203
"//litert/cc:litert_buffer_ref",
204+
"//litert/cc:litert_options",
191205
"//litert/core/model",
192206
"//litert/test:matchers",
193207
"//litert/test:simple_buffer",
@@ -367,7 +381,6 @@ litert_define_ats(
367381
# "^(?:(?!npu_ats_87).)*$$",
368382
],
369383
extra_flags = [
370-
"--f16_range_for_f32=true",
371384
"--data_seed=42",
372385
],
373386
jit_suffix = "",

litert/ats/ats.cc

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <iostream>
1818

1919
#include <gtest/gtest.h>
20+
#include "absl/flags/flag.h" // from @com_google_absl
2021
#include "absl/flags/parse.h" // from @com_google_absl
2122
#include "absl/log/absl_check.h" // from @com_google_absl
2223
#include "litert/ats/compile_fixture.h"
@@ -119,7 +120,29 @@ int Ats() {
119120
} // namespace litert::testing
120121

121122
int main(int argc, char** argv) {
122-
::testing::InitGoogleTest(&argc, argv);
123-
absl::ParseCommandLine(argc, argv);
123+
// Shim to support repeatable flags which absl does not.
124+
std::vector<char*> absl_flags;
125+
static constexpr absl::string_view kDoRegisterPrefix = "--do_register=";
126+
static constexpr absl::string_view kDontRegisterPrefix = "--dont_register=";
127+
std::vector<std::string> do_register;
128+
std::vector<std::string> dont_register;
129+
for (int i = 0; i < argc; ++i) {
130+
if (::litert::StartsWith(argv[i], kDoRegisterPrefix)) {
131+
do_register.push_back(std::string(
132+
absl::string_view(argv[i]).substr(kDoRegisterPrefix.size())));
133+
} else if (::litert::StartsWith(argv[i], kDontRegisterPrefix)) {
134+
dont_register.push_back(std::string(
135+
absl::string_view(argv[i]).substr(kDontRegisterPrefix.size())));
136+
} else {
137+
absl_flags.push_back(argv[i]);
138+
}
139+
}
140+
141+
absl::SetFlag(&FLAGS_do_register, do_register);
142+
absl::SetFlag(&FLAGS_dont_register, dont_register);
143+
144+
int absl_argc = absl_flags.size();
145+
::testing::InitGoogleTest(&absl_argc, absl_flags.data());
146+
absl::ParseCommandLine(absl_argc, absl_flags.data());
124147
return litert::testing::Ats();
125148
}

litert/ats/check_ats.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ Expected<void> CheckAts() {
179179

180180
LITERT_ASSIGN_OR_RETURN(auto exec,
181181
NpuCompiledModelExecutor::Create(
182-
*model, npu_inference_options.DispatchDir()));
182+
*model, npu_inference_options.TargetOptions(),
183+
npu_inference_options.DispatchDir()));
183184
const auto& subgraph = *model->Subgraphs()[0];
184185
LITERT_ASSIGN_OR_RETURN(
185186
auto inputs, SimpleBuffer::LikeSignature(subgraph.Inputs().begin(),

litert/ats/common.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ struct TestNames {
4848
}
4949

5050
// Create with an explicit desc.
51-
static TestNames Create(size_t test_id, absl::string_view family,
52-
absl::string_view logic, absl::string_view test,
51+
static TestNames Create(size_t test_id, absl::string_view fixture,
52+
absl::string_view source, absl::string_view test,
53+
absl::string_view report_id,
5354
absl::string_view desc = "") {
54-
auto suite = MakeSuite(test_id, family, logic);
55-
return {suite, std::string(logic), std::string(desc), std::string(test)};
55+
auto suite = MakeSuite(test_id, fixture, source);
56+
return {suite, std::string(test), std::string(desc),
57+
std::string(report_id)};
5658
}
5759

5860
private:
@@ -98,7 +100,7 @@ enum class CompilationStatus {
98100
// Timing related types.
99101
using Clock = std::chrono::steady_clock;
100102
using TimePoint = Clock::time_point;
101-
using Nanoseconds = uint64_t;
103+
using Microseconds = uint64_t;
102104

103105
// Which backend to use as the "actual".
104106
enum class ExecutionBackend { kCpu, kGpu, kNpu };
@@ -175,7 +177,7 @@ void AbslStringify(Sink& sink, const CompilationStatus& status) {
175177
}
176178

177179
template <typename Sink>
178-
void AbslStringify(Sink& sink, const Nanoseconds& ns) {
180+
void AbslStringify(Sink& sink, const Microseconds& ns) {
179181
absl::Format(&sink, "%e", ns);
180182
}
181183

litert/ats/compile_capture.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,25 @@
2929
namespace litert::testing {
3030

3131
// Information about the time taken to compile the model.
32-
class CompilationTime : public Printable<Nanoseconds> {
32+
class CompilationTime : public Printable<Microseconds> {
3333
public:
3434
// Start timing.
3535
TimePoint Start() const { return Clock::now(); }
3636

3737
// Stop timing and record the latency.
3838
void Stop(const TimePoint& start) {
39-
std::chrono::duration<Nanoseconds, std::nano> nano = Clock::now() - start;
39+
std::chrono::duration<Microseconds, std::nano> nano = Clock::now() - start;
4040
nanos_ = nano.count();
4141
}
4242

43-
Nanoseconds Nanos() const { return nanos_; }
43+
Microseconds Nanos() const { return nanos_; }
4444

4545
CompilationTime() : Printable("CompilationTime", "compile_time(ns)") {}
4646

4747
private:
4848
Fields GetFields() const override { return Fields{nanos_}; }
4949

50-
Nanoseconds nanos_ = std::numeric_limits<Nanoseconds>::max();
50+
Microseconds nanos_ = std::numeric_limits<Microseconds>::max();
5151
};
5252

5353
// Type to hold all of the capturable information related compilation test

litert/ats/compile_capture_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TEST(AtsCompileCaptureTest, Basic) {
3333
e.model.name = "FOO";
3434

3535
ASSERT_NE(e.compilation_time.Nanos(),
36-
std::numeric_limits<Nanoseconds>::max());
36+
std::numeric_limits<Microseconds>::max());
3737

3838
std::ostringstream s;
3939
cap.Print(s);

litert/ats/configure.cc

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@
3333
#include "litert/c/litert_common.h"
3434
#include "litert/cc/litert_expected.h"
3535
#include "litert/cc/litert_macros.h"
36+
#include "litert/cc/litert_options.h"
3637
#include "litert/compiler/plugin/compiler_plugin.h"
38+
#include "litert/tools/flags/vendors/mediatek_flags.h" // IWYU pragma: export
39+
#include "litert/tools/flags/vendors/qualcomm_flags.h" // IWYU pragma: export
3740

3841
ABSL_FLAG(std::optional<int>, data_seed, std::nullopt,
3942
"Seed for the buffer data generation.");
@@ -58,11 +61,11 @@ ABSL_FLAG(std::string, plugin_dir, "",
5861
"relevant for NPU.");
5962

6063
ABSL_FLAG(
61-
std::string, dont_register, "^$",
64+
std::vector<std::string>, dont_register, std::vector<std::string>{},
6265
"Regex for test selection. This is a negative search match, if the pattern "
6366
"can be found anywhere in the test name, it will be skipped.");
6467

65-
ABSL_FLAG(std::string, do_register, ".*",
68+
ABSL_FLAG(std::vector<std::string>, do_register, std::vector<std::string>{},
6669
"Regex for test selection. This is a positive search match, if the "
6770
"pattern can be found anywhere in the test name, it will be run. "
6871
"This has lower priority over the dont_register filter.");
@@ -113,6 +116,9 @@ namespace litert::testing {
113116

114117
namespace {
115118

119+
using ::litert::mediatek::MediatekOptionsFromFlags;
120+
using ::litert::qualcomm::QualcommOptionsFromFlags;
121+
116122
Expected<AtsConf::SeedMap> ParseParamSeedMap() {
117123
const auto seed_flags = absl::GetFlag(FLAGS_seeds);
118124
AtsConf::SeedMap seeds;
@@ -143,15 +149,34 @@ Expected<ExecutionBackend> ParseBackend() {
143149
}
144150
}
145151

152+
Expected<Options> ParseOptions(ExecutionBackend backend) {
153+
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
154+
if (backend == ExecutionBackend::kNpu) {
155+
if (auto qnn_opts = QualcommOptionsFromFlags()) {
156+
options.AddOpaqueOptions(std::move(*qnn_opts));
157+
}
158+
if (auto mediatek_opts = MediatekOptionsFromFlags()) {
159+
options.AddOpaqueOptions(std::move(*mediatek_opts));
160+
}
161+
options.SetHardwareAccelerators(kLiteRtHwAcceleratorNpu);
162+
} else if (backend == ExecutionBackend::kCpu) {
163+
options.SetHardwareAccelerators(kLiteRtHwAcceleratorCpu);
164+
} else if (backend == ExecutionBackend::kGpu) {
165+
options.SetHardwareAccelerators(kLiteRtHwAcceleratorGpu);
166+
}
167+
return options;
168+
}
169+
146170
Expected<std::optional<internal::CompilerPlugin>> ParsePlugin(
147171
absl::string_view plugin_dir, absl::string_view soc_manufacturer,
148-
bool compile_mode) {
172+
bool compile_mode, const Options& litert_options) {
149173
using R = std::optional<internal::CompilerPlugin>;
150174
if (!compile_mode) {
151175
return R(std::nullopt);
152176
}
153177
LITERT_ASSIGN_OR_RETURN(auto plugin, internal::CompilerPlugin::FindPlugin(
154-
soc_manufacturer, {plugin_dir}));
178+
soc_manufacturer, {plugin_dir},
179+
nullptr, litert_options.Get()));
155180
return R(std::move(plugin));
156181
}
157182

@@ -166,12 +191,15 @@ void Setup(const AtsConf& options) {
166191
Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
167192
LITERT_ASSIGN_OR_RETURN(auto seeds, ParseParamSeedMap());
168193
LITERT_ASSIGN_OR_RETURN(auto backend, ParseBackend());
169-
std::regex neg_re(absl::GetFlag(FLAGS_dont_register),
170-
std::regex_constants::ECMAScript);
171-
std::regex pos_re(absl::GetFlag(FLAGS_do_register),
172-
std::regex_constants::ECMAScript);
194+
std::vector<std::regex> neg_re;
195+
for (const auto& re : absl::GetFlag(FLAGS_dont_register)) {
196+
neg_re.push_back(std::regex(re, std::regex_constants::ECMAScript));
197+
}
198+
std::vector<std::regex> pos_re;
199+
for (const auto& re : absl::GetFlag(FLAGS_do_register)) {
200+
pos_re.push_back(std::regex(re, std::regex_constants::ECMAScript));
201+
}
173202
auto extra_models = absl::GetFlag(FLAGS_extra_models);
174-
auto f16_range_for_f32 = absl::GetFlag(FLAGS_f16_range_for_f32);
175203
auto data_seed = absl::GetFlag(FLAGS_data_seed);
176204
auto dispatch_dir = absl::GetFlag(FLAGS_dispatch_dir);
177205
auto plugin_dir = absl::GetFlag(FLAGS_plugin_dir);
@@ -190,15 +218,19 @@ Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
190218
auto limit = absl::GetFlag(FLAGS_limit);
191219
auto soc_manufacturer = absl::GetFlag(FLAGS_soc_manufacturer);
192220
auto soc_model = absl::GetFlag(FLAGS_soc_model);
221+
LITERT_ASSIGN_OR_RETURN(auto target_options, ParseOptions(backend));
222+
LITERT_ASSIGN_OR_RETURN(auto reference_options, Options::Create());
223+
reference_options.SetHardwareAccelerators(kLiteRtHwAcceleratorCpu);
193224
LITERT_ASSIGN_OR_RETURN(
194-
auto plugin, ParsePlugin(plugin_dir, soc_manufacturer, compile_mode));
225+
auto plugin,
226+
ParsePlugin(plugin_dir, soc_manufacturer, compile_mode, target_options));
195227
AtsConf res(std::move(seeds), backend, quiet, dispatch_dir, plugin_dir,
196228
std::move(neg_re), std::move(pos_re), std::move(extra_models),
197-
f16_range_for_f32, data_seed, iters_per_test,
198-
std::move(max_ms_per_test_opt), fail_on_timeout, dump_report,
199-
std::move(csv), compile_mode, std::move(models_out), limit,
200-
std::move(plugin), std::move(soc_manufacturer),
201-
std::move(soc_model));
229+
data_seed, iters_per_test, std::move(max_ms_per_test_opt),
230+
fail_on_timeout, dump_report, std::move(csv), compile_mode,
231+
std::move(models_out), limit, std::move(plugin),
232+
std::move(soc_manufacturer), std::move(soc_model),
233+
std::move(target_options), std::move(reference_options));
202234
Setup(res);
203235
return res;
204236
}
@@ -213,7 +245,15 @@ int AtsConf::GetSeedForParams(absl::string_view name) const {
213245
}
214246

215247
bool AtsConf::ShouldRegister(const std::string& name) const {
216-
return std::regex_search(name, pos_re_) && !std::regex_search(name, neg_re_);
248+
const bool include =
249+
pos_re_.empty() ||
250+
std::any_of(pos_re_.begin(), pos_re_.end(), [&name](const auto& re) {
251+
return std::regex_search(name, re);
252+
});
253+
const bool exclude = std::any_of(
254+
neg_re_.begin(), neg_re_.end(),
255+
[&name](const auto& re) { return std::regex_search(name, re); });
256+
return include && !exclude;
217257
};
218258

219259
bool AtsConf::ShouldRegister(absl::string_view name) const {

0 commit comments

Comments
 (0)