3333#include " litert/c/litert_common.h"
3434#include " litert/cc/litert_expected.h"
3535#include " litert/cc/litert_macros.h"
36+ #include " litert/cc/litert_options.h"
3637#include " litert/compiler/plugin/compiler_plugin.h"
38+ #include " litert/tools/flags/vendors/mediatek_flags.h" // IWYU pragma: export
39+ #include " litert/tools/flags/vendors/qualcomm_flags.h" // IWYU pragma: export
3740
3841ABSL_FLAG (std::optional<int >, data_seed, std::nullopt ,
3942 " Seed for the buffer data generation." );
@@ -58,11 +61,11 @@ ABSL_FLAG(std::string, plugin_dir, "",
5861 " relevant for NPU." );
5962
6063ABSL_FLAG (
61- std::string, dont_register, " ^$ " ,
64+ std::vector<std:: string> , dont_register, std::vector<std::string>{} ,
6265 " Regex for test selection. This is a negative search match, if the pattern "
6366 " can be found anywhere in the test name, it will be skipped." );
6467
65- ABSL_FLAG (std::string, do_register, " .* " ,
68+ ABSL_FLAG (std::vector<std:: string> , do_register, std::vector<std::string>{} ,
6669 " Regex for test selection. This is a positive search match, if the "
6770 " pattern can be found anywhere in the test name, it will be run. "
6871 " This has lower priority over the dont_register filter." );
@@ -113,6 +116,9 @@ namespace litert::testing {
113116
114117namespace {
115118
119+ using ::litert::mediatek::MediatekOptionsFromFlags;
120+ using ::litert::qualcomm::QualcommOptionsFromFlags;
121+
116122Expected<AtsConf::SeedMap> ParseParamSeedMap () {
117123 const auto seed_flags = absl::GetFlag (FLAGS_seeds);
118124 AtsConf::SeedMap seeds;
@@ -143,15 +149,34 @@ Expected<ExecutionBackend> ParseBackend() {
143149 }
144150}
145151
152+ Expected<Options> ParseOptions (ExecutionBackend backend) {
153+ LITERT_ASSIGN_OR_RETURN (auto options, Options::Create ());
154+ if (backend == ExecutionBackend::kNpu ) {
155+ if (auto qnn_opts = QualcommOptionsFromFlags ()) {
156+ options.AddOpaqueOptions (std::move (*qnn_opts));
157+ }
158+ if (auto mediatek_opts = MediatekOptionsFromFlags ()) {
159+ options.AddOpaqueOptions (std::move (*mediatek_opts));
160+ }
161+ options.SetHardwareAccelerators (kLiteRtHwAcceleratorNpu );
162+ } else if (backend == ExecutionBackend::kCpu ) {
163+ options.SetHardwareAccelerators (kLiteRtHwAcceleratorCpu );
164+ } else if (backend == ExecutionBackend::kGpu ) {
165+ options.SetHardwareAccelerators (kLiteRtHwAcceleratorGpu );
166+ }
167+ return options;
168+ }
169+
146170Expected<std::optional<internal::CompilerPlugin>> ParsePlugin (
147171 absl::string_view plugin_dir, absl::string_view soc_manufacturer,
148- bool compile_mode) {
172+ bool compile_mode, const Options& litert_options ) {
149173 using R = std::optional<internal::CompilerPlugin>;
150174 if (!compile_mode) {
151175 return R (std::nullopt );
152176 }
153177 LITERT_ASSIGN_OR_RETURN (auto plugin, internal::CompilerPlugin::FindPlugin (
154- soc_manufacturer, {plugin_dir}));
178+ soc_manufacturer, {plugin_dir},
179+ nullptr , litert_options.Get ()));
155180 return R (std::move (plugin));
156181}
157182
@@ -166,12 +191,15 @@ void Setup(const AtsConf& options) {
166191Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup () {
167192 LITERT_ASSIGN_OR_RETURN (auto seeds, ParseParamSeedMap ());
168193 LITERT_ASSIGN_OR_RETURN (auto backend, ParseBackend ());
169- std::regex neg_re (absl::GetFlag (FLAGS_dont_register),
170- std::regex_constants::ECMAScript);
171- std::regex pos_re (absl::GetFlag (FLAGS_do_register),
172- std::regex_constants::ECMAScript);
194+ std::vector<std::regex> neg_re;
195+ for (const auto & re : absl::GetFlag (FLAGS_dont_register)) {
196+ neg_re.push_back (std::regex (re, std::regex_constants::ECMAScript));
197+ }
198+ std::vector<std::regex> pos_re;
199+ for (const auto & re : absl::GetFlag (FLAGS_do_register)) {
200+ pos_re.push_back (std::regex (re, std::regex_constants::ECMAScript));
201+ }
173202 auto extra_models = absl::GetFlag (FLAGS_extra_models);
174- auto f16_range_for_f32 = absl::GetFlag (FLAGS_f16_range_for_f32);
175203 auto data_seed = absl::GetFlag (FLAGS_data_seed);
176204 auto dispatch_dir = absl::GetFlag (FLAGS_dispatch_dir);
177205 auto plugin_dir = absl::GetFlag (FLAGS_plugin_dir);
@@ -190,15 +218,19 @@ Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
190218 auto limit = absl::GetFlag (FLAGS_limit);
191219 auto soc_manufacturer = absl::GetFlag (FLAGS_soc_manufacturer);
192220 auto soc_model = absl::GetFlag (FLAGS_soc_model);
221+ LITERT_ASSIGN_OR_RETURN (auto target_options, ParseOptions (backend));
222+ LITERT_ASSIGN_OR_RETURN (auto reference_options, Options::Create ());
223+ reference_options.SetHardwareAccelerators (kLiteRtHwAcceleratorCpu );
193224 LITERT_ASSIGN_OR_RETURN (
194- auto plugin, ParsePlugin (plugin_dir, soc_manufacturer, compile_mode));
225+ auto plugin,
226+ ParsePlugin (plugin_dir, soc_manufacturer, compile_mode, target_options));
195227 AtsConf res (std::move (seeds), backend, quiet, dispatch_dir, plugin_dir,
196228 std::move (neg_re), std::move (pos_re), std::move (extra_models),
197- f16_range_for_f32, data_seed, iters_per_test,
198- std::move (max_ms_per_test_opt ), fail_on_timeout, dump_report ,
199- std::move (csv ), compile_mode , std::move (models_out), limit ,
200- std::move (plugin ), std::move (soc_manufacturer ),
201- std::move (soc_model ));
229+ data_seed, iters_per_test, std::move (max_ms_per_test_opt) ,
230+ fail_on_timeout, dump_report, std::move (csv ), compile_mode ,
231+ std::move (models_out ), limit , std::move (plugin) ,
232+ std::move (soc_manufacturer ), std::move (soc_model ),
233+ std::move (target_options), std::move (reference_options ));
202234 Setup (res);
203235 return res;
204236}
@@ -213,7 +245,15 @@ int AtsConf::GetSeedForParams(absl::string_view name) const {
213245}
214246
215247bool AtsConf::ShouldRegister (const std::string& name) const {
216- return std::regex_search (name, pos_re_) && !std::regex_search (name, neg_re_);
248+ const bool include =
249+ pos_re_.empty () ||
250+ std::any_of (pos_re_.begin (), pos_re_.end (), [&name](const auto & re) {
251+ return std::regex_search (name, re);
252+ });
253+ const bool exclude = std::any_of (
254+ neg_re_.begin (), neg_re_.end (),
255+ [&name](const auto & re) { return std::regex_search (name, re); });
256+ return include && !exclude;
217257};
218258
219259bool AtsConf::ShouldRegister (absl::string_view name) const {
0 commit comments