1414
1515#include  " litert/ats/configure.h" 
1616
17+ #include  < algorithm> 
1718#include  < chrono>    //  NOLINT
1819#include  < cstddef> 
1920#include  < cstdint> 
3334#include  " litert/c/litert_common.h" 
3435#include  " litert/cc/litert_expected.h" 
3536#include  " litert/cc/litert_macros.h" 
37+ #include  " litert/cc/litert_options.h" 
3638#include  " litert/compiler/plugin/compiler_plugin.h" 
39+ #include  " litert/tools/flags/vendors/mediatek_flags.h"    //  IWYU pragma: export
40+ #include  " litert/tools/flags/vendors/qualcomm_flags.h"    //  IWYU pragma: export
3741
3842ABSL_FLAG (std::optional<int >, data_seed, std::nullopt ,
3943          " Seed for the buffer data generation."  );
@@ -58,11 +62,11 @@ ABSL_FLAG(std::string, plugin_dir, "",
5862          " relevant for NPU."  );
5963
6064ABSL_FLAG (
61-     std::string, dont_register, " ^$ "  ,
65+     std::vector<std:: string> , dont_register, std::vector<std::string>{} ,
6266    " Regex for test selection. This is a negative search match, if the pattern " 
6367    " can be found anywhere in the test name, it will be skipped."  );
6468
65- ABSL_FLAG (std::string, do_register, " .* "  ,
69+ ABSL_FLAG (std::vector<std:: string> , do_register, std::vector<std::string>{} ,
6670          " Regex for test selection. This is a positive search match, if the " 
6771          " pattern can be found anywhere in the test name, it will be run. " 
6872          " This has lower priority over the dont_register filter."  );
@@ -113,6 +117,9 @@ namespace litert::testing {
113117
114118namespace  {
115119
120+ using  ::litert::mediatek::MediatekOptionsFromFlags;
121+ using  ::litert::qualcomm::QualcommOptionsFromFlags;
122+ 
116123Expected<AtsConf::SeedMap> ParseParamSeedMap () {
117124  const  auto  seed_flags = absl::GetFlag (FLAGS_seeds);
118125  AtsConf::SeedMap seeds;
@@ -143,15 +150,34 @@ Expected<ExecutionBackend> ParseBackend() {
143150  }
144151}
145152
153+ Expected<Options> ParseOptions (ExecutionBackend backend) {
154+   LITERT_ASSIGN_OR_RETURN (auto  options, Options::Create ());
155+   if  (backend == ExecutionBackend::kNpu ) {
156+     if  (auto  qnn_opts = QualcommOptionsFromFlags ()) {
157+       options.AddOpaqueOptions (std::move (*qnn_opts));
158+     }
159+     if  (auto  mediatek_opts = MediatekOptionsFromFlags ()) {
160+       options.AddOpaqueOptions (std::move (*mediatek_opts));
161+     }
162+     options.SetHardwareAccelerators (kLiteRtHwAcceleratorNpu );
163+   } else  if  (backend == ExecutionBackend::kCpu ) {
164+     options.SetHardwareAccelerators (kLiteRtHwAcceleratorCpu );
165+   } else  if  (backend == ExecutionBackend::kGpu ) {
166+     options.SetHardwareAccelerators (kLiteRtHwAcceleratorGpu );
167+   }
168+   return  options;
169+ }
170+ 
146171Expected<std::optional<internal::CompilerPlugin>> ParsePlugin (
147172    absl::string_view plugin_dir, absl::string_view soc_manufacturer,
148-     bool  compile_mode) {
173+     bool  compile_mode,  const  Options& litert_options ) {
149174  using  R = std::optional<internal::CompilerPlugin>;
150175  if  (!compile_mode) {
151176    return  R (std::nullopt );
152177  }
153178  LITERT_ASSIGN_OR_RETURN (auto  plugin, internal::CompilerPlugin::FindPlugin (
154-                                            soc_manufacturer, {plugin_dir}));
179+                                            soc_manufacturer, {plugin_dir},
180+                                            nullptr , litert_options.Get ()));
155181  return  R (std::move (plugin));
156182}
157183
@@ -166,12 +192,15 @@ void Setup(const AtsConf& options) {
166192Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup () {
167193  LITERT_ASSIGN_OR_RETURN (auto  seeds, ParseParamSeedMap ());
168194  LITERT_ASSIGN_OR_RETURN (auto  backend, ParseBackend ());
169-   std::regex neg_re (absl::GetFlag (FLAGS_dont_register),
170-                     std::regex_constants::ECMAScript);
171-   std::regex pos_re (absl::GetFlag (FLAGS_do_register),
172-                     std::regex_constants::ECMAScript);
195+   std::vector<std::regex> neg_re;
196+   for  (const  auto & re : absl::GetFlag (FLAGS_dont_register)) {
197+     neg_re.push_back (std::regex (re, std::regex_constants::ECMAScript));
198+   }
199+   std::vector<std::regex> pos_re;
200+   for  (const  auto & re : absl::GetFlag (FLAGS_do_register)) {
201+     pos_re.push_back (std::regex (re, std::regex_constants::ECMAScript));
202+   }
173203  auto  extra_models = absl::GetFlag (FLAGS_extra_models);
174-   auto  f16_range_for_f32 = absl::GetFlag (FLAGS_f16_range_for_f32);
175204  auto  data_seed = absl::GetFlag (FLAGS_data_seed);
176205  auto  dispatch_dir = absl::GetFlag (FLAGS_dispatch_dir);
177206  auto  plugin_dir = absl::GetFlag (FLAGS_plugin_dir);
@@ -190,15 +219,19 @@ Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
190219  auto  limit = absl::GetFlag (FLAGS_limit);
191220  auto  soc_manufacturer = absl::GetFlag (FLAGS_soc_manufacturer);
192221  auto  soc_model = absl::GetFlag (FLAGS_soc_model);
222+   LITERT_ASSIGN_OR_RETURN (auto  target_options, ParseOptions (backend));
223+   LITERT_ASSIGN_OR_RETURN (auto  reference_options, Options::Create ());
224+   reference_options.SetHardwareAccelerators (kLiteRtHwAcceleratorCpu );
193225  LITERT_ASSIGN_OR_RETURN (
194-       auto  plugin, ParsePlugin (plugin_dir, soc_manufacturer, compile_mode));
226+       auto  plugin,
227+       ParsePlugin (plugin_dir, soc_manufacturer, compile_mode, target_options));
195228  AtsConf res (std::move (seeds), backend, quiet, dispatch_dir, plugin_dir,
196229              std::move (neg_re), std::move (pos_re), std::move (extra_models),
197-               f16_range_for_f32,  data_seed, iters_per_test,
198-               std::move (max_ms_per_test_opt ), fail_on_timeout, dump_report ,
199-               std::move (csv ), compile_mode , std::move (models_out), limit ,
200-               std::move (plugin ), std::move (soc_manufacturer ),
201-               std::move (soc_model ));
230+               data_seed, iters_per_test,  std::move (max_ms_per_test_opt) ,
231+               fail_on_timeout, dump_report,  std::move (csv ), compile_mode ,
232+               std::move (models_out ), limit , std::move (plugin) ,
233+               std::move (soc_manufacturer ), std::move (soc_model ),
234+               std::move (target_options),  std::move (reference_options ));
202235  Setup (res);
203236  return  res;
204237}
@@ -213,7 +246,15 @@ int AtsConf::GetSeedForParams(absl::string_view name) const {
213246}
214247
215248bool  AtsConf::ShouldRegister (const  std::string& name) const  {
216-   return  std::regex_search (name, pos_re_) && !std::regex_search (name, neg_re_);
249+   const  bool  include =
250+       pos_re_.empty () ||
251+       std::any_of (pos_re_.begin (), pos_re_.end (), [&name](const  auto & re) {
252+         return  std::regex_search (name, re);
253+       });
254+   const  bool  exclude = std::any_of (
255+       neg_re_.begin (), neg_re_.end (),
256+       [&name](const  auto & re) { return  std::regex_search (name, re); });
257+   return  include && !exclude;
217258};
218259
219260bool  AtsConf::ShouldRegister (absl::string_view name) const  {
0 commit comments