Skip to content

Commit ac42fca

Browse files
LukeBoyercopybara-github
authored andcommitted
Make extra models flag take a list of models or a dir.
LiteRT-PiperOrigin-RevId: 820354611
1 parent 9ac3eba commit ac42fca

File tree

3 files changed

+21
-12
lines changed

3 files changed

+21
-12
lines changed

litert/ats/check_ats.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ Expected<AtsConf> CompileOptions() {
7474
}
7575

7676
Expected<void> CheckAts() {
77-
absl::SetFlag(&FLAGS_extra_models, GetLiteRtPath("test/testdata/"));
77+
absl::SetFlag(&FLAGS_extra_models, {GetLiteRtPath("test/testdata/")});
7878

7979
LITERT_ASSIGN_OR_RETURN(auto dir, UniqueTestDirectory::Create());
8080
absl::SetFlag(&FLAGS_models_out, dir.Str());

litert/ats/configure.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ ABSL_FLAG(
7171
bool, f16_range_for_f32, false,
7272
"If true, will generate values f16 values stored as f32 for f32 tensors.");
7373

74-
ABSL_FLAG(std::string, extra_models, "",
75-
"Optional directory containing models which to add to the test.");
74+
ABSL_FLAG(std::vector<std::string>, extra_models, {},
75+
"Optional list of directories, or model files to add to the test.");
7676

7777
ABSL_FLAG(size_t, iters_per_test, 1, "Number of iterations per test.");
7878

litert/ats/configure.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ ABSL_DECLARE_FLAG(std::string, do_register);
6262
// Will generate values for f32 tensors in the range of f16 values.
6363
ABSL_DECLARE_FLAG(bool, f16_range_for_f32);
6464

65-
// Optional directory containing models which to add to the test.
66-
ABSL_DECLARE_FLAG(std::string, extra_models);
65+
// Optional list of directories, or model files to add to the test.
66+
ABSL_DECLARE_FLAG(std::vector<std::string>, extra_models);
6767

6868
// Number of iterations per test, each one will have different tensor data.
6969
ABSL_DECLARE_FLAG(size_t, iters_per_test);
@@ -141,11 +141,19 @@ class AtsConf {
141141

142142
// List of models to add to the test.
143143
std::vector<std::string> ExtraModels() const {
144-
auto res = internal::ListDir(extra_models_);
145-
if (!res) {
146-
return {};
144+
std::vector<std::string> res;
145+
for (const auto& model : extra_models_) {
146+
if (internal::IsDir(model)) {
147+
auto list = internal::ListDir(model);
148+
if (!list) {
149+
continue;
150+
}
151+
res.insert(res.end(), list->begin(), list->end());
152+
} else {
153+
res.push_back(model);
154+
}
147155
}
148-
return *res;
156+
return res;
149157
}
150158

151159
// Number of iterations per test, each one will have different tensor data.
@@ -233,8 +241,9 @@ class AtsConf {
233241
explicit AtsConf(SeedMap&& seeds_for_params, ExecutionBackend backend,
234242
bool quiet, std::string dispatch_dir, std::string plugin_dir,
235243
std::regex&& neg_re, std::regex&& pos_re,
236-
std::string extra_models, bool f16_range_for_f32,
237-
std::optional<int> data_seed, size_t iters_per_test,
244+
std::vector<std::string> extra_models,
245+
bool f16_range_for_f32, std::optional<int> data_seed,
246+
size_t iters_per_test,
238247
std::chrono::milliseconds max_ms_per_test,
239248
bool fail_on_timeout, bool dump_report, std::string csv,
240249
bool compile_mode, std::string models_out, int32_t limit,
@@ -274,7 +283,7 @@ class AtsConf {
274283
std::string plugin_dir_;
275284
std::regex neg_re_;
276285
std::regex pos_re_;
277-
std::string extra_models_;
286+
std::vector<std::string> extra_models_;
278287
bool f16_range_for_f32_;
279288
std::optional<int> data_seed_;
280289
size_t iters_per_test_;

0 commit comments

Comments
 (0)