Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions litert/ats/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ litert_test(
"//litert/test/generators",
"//litert/test/generators:common",
"//tflite/schema:schema_fbs",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest",
],
)
Expand All @@ -62,17 +64,25 @@ cc_library(
testonly = True,
srcs = ["configure.cc"],
hdrs = ["configure.h"],
copts = commandline_flag_copts(),
copts = commandline_flag_copts() + [
"-DINCLUDE_QUALCOMM_COMPILE_FLAGS",
"-DINCLUDE_QUALCOMM_RUNTIME_FLAGS",
"-DINCLUDE_MEDIATEK_COMPILE_FLAGS",
"-DINCLUDE_MEDIATEK_RUNTIME_FLAGS",
],
deps = [
":common",
"//litert/c:litert_common",
"//litert/c/internal:litert_logging",
"//litert/cc:litert_expected",
"//litert/cc:litert_macros",
"//litert/cc:litert_options",
"//litert/cc/internal:litert_rng",
"//litert/compiler/plugin:compiler_plugin",
"//litert/core:filesystem_testonly",
"//litert/core/model:model_serialize",
"//litert/tools/flags/vendors:mediatek_flags",
"//litert/tools/flags/vendors:qualcomm_flags",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -111,6 +121,7 @@ cc_library(
"//litert/cc:litert_expected",
"//litert/cc/internal:litert_detail",
"//litert/cc/internal:litert_rng",
"//litert/core:filesystem",
"//litert/test/generators",
],
)
Expand All @@ -120,8 +131,10 @@ cc_test(
srcs = ["executor_test.cc"],
deps = [
":executor",
"//litert/c:litert_common",
"//litert/c:litert_op_code",
"//litert/cc:litert_buffer_ref",
"//litert/cc:litert_options",
"//litert/core/model",
"//litert/test:matchers",
"//litert/test:simple_buffer",
Expand Down Expand Up @@ -186,8 +199,10 @@ litert_device_test(
defines = ["_TEST_NPU"],
deps = [
":executor",
"//litert/c:litert_common",
"//litert/c:litert_op_code",
"//litert/cc:litert_buffer_ref",
"//litert/cc:litert_options",
"//litert/core/model",
"//litert/test:matchers",
"//litert/test:simple_buffer",
Expand Down Expand Up @@ -367,7 +382,6 @@ litert_define_ats(
# "^(?:(?!npu_ats_87).)*$$",
],
extra_flags = [
"--f16_range_for_f32=true",
"--data_seed=42",
],
jit_suffix = "",
Expand Down
30 changes: 28 additions & 2 deletions litert/ats/ats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "absl/flags/flag.h" // from @com_google_absl
#include "absl/flags/parse.h" // from @com_google_absl
#include "absl/log/absl_check.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "litert/ats/compile_fixture.h"
#include "litert/ats/configure.h"
#include "litert/ats/inference_fixture.h"
Expand Down Expand Up @@ -119,7 +123,29 @@ int Ats() {
} // namespace litert::testing

int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
absl::ParseCommandLine(argc, argv);
// Shim to support repeatable flags which absl does not.
std::vector<char*> absl_flags;
static constexpr absl::string_view kDoRegisterPrefix = "--do_register=";
static constexpr absl::string_view kDontRegisterPrefix = "--dont_register=";
std::vector<std::string> do_register;
std::vector<std::string> dont_register;
for (int i = 0; i < argc; ++i) {
if (::litert::StartsWith(argv[i], kDoRegisterPrefix)) {
do_register.push_back(std::string(
absl::string_view(argv[i]).substr(kDoRegisterPrefix.size())));
} else if (::litert::StartsWith(argv[i], kDontRegisterPrefix)) {
dont_register.push_back(std::string(
absl::string_view(argv[i]).substr(kDontRegisterPrefix.size())));
} else {
absl_flags.push_back(argv[i]);
}
}

absl::SetFlag(&FLAGS_do_register, do_register);
absl::SetFlag(&FLAGS_dont_register, dont_register);

int absl_argc = absl_flags.size();
::testing::InitGoogleTest(&absl_argc, absl_flags.data());
absl::ParseCommandLine(absl_argc, absl_flags.data());
return litert::testing::Ats();
}
3 changes: 2 additions & 1 deletion litert/ats/check_ats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ Expected<void> CheckAts() {

LITERT_ASSIGN_OR_RETURN(auto exec,
NpuCompiledModelExecutor::Create(
*model, npu_inference_options.DispatchDir()));
*model, npu_inference_options.TargetOptions(),
npu_inference_options.DispatchDir()));
const auto& subgraph = *model->Subgraphs()[0];
LITERT_ASSIGN_OR_RETURN(
auto inputs, SimpleBuffer::LikeSignature(subgraph.Inputs().begin(),
Expand Down
14 changes: 8 additions & 6 deletions litert/ats/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ struct TestNames {
}

// Create with an explicit desc.
static TestNames Create(size_t test_id, absl::string_view family,
absl::string_view logic, absl::string_view test,
static TestNames Create(size_t test_id, absl::string_view fixture,
absl::string_view source, absl::string_view test,
absl::string_view report_id,
absl::string_view desc = "") {
auto suite = MakeSuite(test_id, family, logic);
return {suite, std::string(logic), std::string(desc), std::string(test)};
auto suite = MakeSuite(test_id, fixture, source);
return {suite, std::string(test), std::string(desc),
std::string(report_id)};
}

private:
Expand Down Expand Up @@ -98,7 +100,7 @@ enum class CompilationStatus {
// Timing related types.
using Clock = std::chrono::steady_clock;
using TimePoint = Clock::time_point;
using Nanoseconds = uint64_t;
using Microseconds = uint64_t;

// Which backend to use as the "actual".
enum class ExecutionBackend { kCpu, kGpu, kNpu };
Expand Down Expand Up @@ -175,7 +177,7 @@ void AbslStringify(Sink& sink, const CompilationStatus& status) {
}

template <typename Sink>
void AbslStringify(Sink& sink, const Nanoseconds& ns) {
void AbslStringify(Sink& sink, const Microseconds& ns) {
absl::Format(&sink, "%e", ns);
}

Expand Down
8 changes: 4 additions & 4 deletions litert/ats/compile_capture.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,25 @@
namespace litert::testing {

// Information about the time taken to compile the model.
class CompilationTime : public Printable<Nanoseconds> {
class CompilationTime : public Printable<Microseconds> {
public:
// Start timing.
TimePoint Start() const { return Clock::now(); }

// Stop timing and record the latency.
void Stop(const TimePoint& start) {
std::chrono::duration<Nanoseconds, std::nano> nano = Clock::now() - start;
std::chrono::duration<Microseconds, std::nano> nano = Clock::now() - start;
nanos_ = nano.count();
}

Nanoseconds Nanos() const { return nanos_; }
Microseconds Nanos() const { return nanos_; }

CompilationTime() : Printable("CompilationTime", "compile_time(ns)") {}

private:
Fields GetFields() const override { return Fields{nanos_}; }

Nanoseconds nanos_ = std::numeric_limits<Nanoseconds>::max();
Microseconds nanos_ = std::numeric_limits<Microseconds>::max();
};

// Type to hold all of the capturable information related compilation test
Expand Down
2 changes: 1 addition & 1 deletion litert/ats/compile_capture_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ TEST(AtsCompileCaptureTest, Basic) {
e.model.name = "FOO";

ASSERT_NE(e.compilation_time.Nanos(),
std::numeric_limits<Nanoseconds>::max());
std::numeric_limits<Microseconds>::max());

std::ostringstream s;
cap.Print(s);
Expand Down
73 changes: 57 additions & 16 deletions litert/ats/configure.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "litert/ats/configure.h"

#include <algorithm>
#include <chrono> // NOLINT
#include <cstddef>
#include <cstdint>
Expand All @@ -33,7 +34,10 @@
#include "litert/c/litert_common.h"
#include "litert/cc/litert_expected.h"
#include "litert/cc/litert_macros.h"
#include "litert/cc/litert_options.h"
#include "litert/compiler/plugin/compiler_plugin.h"
#include "litert/tools/flags/vendors/mediatek_flags.h" // IWYU pragma: export
#include "litert/tools/flags/vendors/qualcomm_flags.h" // IWYU pragma: export

ABSL_FLAG(std::optional<int>, data_seed, std::nullopt,
"Seed for the buffer data generation.");
Expand All @@ -58,11 +62,11 @@ ABSL_FLAG(std::string, plugin_dir, "",
"relevant for NPU.");

ABSL_FLAG(
std::string, dont_register, "^$",
std::vector<std::string>, dont_register, std::vector<std::string>{},
"Regex for test selection. This is a negative search match, if the pattern "
"can be found anywhere in the test name, it will be skipped.");

ABSL_FLAG(std::string, do_register, ".*",
ABSL_FLAG(std::vector<std::string>, do_register, std::vector<std::string>{},
"Regex for test selection. This is a positive search match, if the "
"pattern can be found anywhere in the test name, it will be run. "
"This has lower priority over the dont_register filter.");
Expand Down Expand Up @@ -113,6 +117,9 @@ namespace litert::testing {

namespace {

using ::litert::mediatek::MediatekOptionsFromFlags;
using ::litert::qualcomm::QualcommOptionsFromFlags;

Expected<AtsConf::SeedMap> ParseParamSeedMap() {
const auto seed_flags = absl::GetFlag(FLAGS_seeds);
AtsConf::SeedMap seeds;
Expand Down Expand Up @@ -143,15 +150,34 @@ Expected<ExecutionBackend> ParseBackend() {
}
}

Expected<Options> ParseOptions(ExecutionBackend backend) {
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
if (backend == ExecutionBackend::kNpu) {
if (auto qnn_opts = QualcommOptionsFromFlags()) {
options.AddOpaqueOptions(std::move(*qnn_opts));
}
if (auto mediatek_opts = MediatekOptionsFromFlags()) {
options.AddOpaqueOptions(std::move(*mediatek_opts));
}
options.SetHardwareAccelerators(kLiteRtHwAcceleratorNpu);
} else if (backend == ExecutionBackend::kCpu) {
options.SetHardwareAccelerators(kLiteRtHwAcceleratorCpu);
} else if (backend == ExecutionBackend::kGpu) {
options.SetHardwareAccelerators(kLiteRtHwAcceleratorGpu);
}
return options;
}

Expected<std::optional<internal::CompilerPlugin>> ParsePlugin(
absl::string_view plugin_dir, absl::string_view soc_manufacturer,
bool compile_mode) {
bool compile_mode, const Options& litert_options) {
using R = std::optional<internal::CompilerPlugin>;
if (!compile_mode) {
return R(std::nullopt);
}
LITERT_ASSIGN_OR_RETURN(auto plugin, internal::CompilerPlugin::FindPlugin(
soc_manufacturer, {plugin_dir}));
soc_manufacturer, {plugin_dir},
nullptr, litert_options.Get()));
return R(std::move(plugin));
}

Expand All @@ -166,12 +192,15 @@ void Setup(const AtsConf& options) {
Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
LITERT_ASSIGN_OR_RETURN(auto seeds, ParseParamSeedMap());
LITERT_ASSIGN_OR_RETURN(auto backend, ParseBackend());
std::regex neg_re(absl::GetFlag(FLAGS_dont_register),
std::regex_constants::ECMAScript);
std::regex pos_re(absl::GetFlag(FLAGS_do_register),
std::regex_constants::ECMAScript);
std::vector<std::regex> neg_re;
for (const auto& re : absl::GetFlag(FLAGS_dont_register)) {
neg_re.push_back(std::regex(re, std::regex_constants::ECMAScript));
}
std::vector<std::regex> pos_re;
for (const auto& re : absl::GetFlag(FLAGS_do_register)) {
pos_re.push_back(std::regex(re, std::regex_constants::ECMAScript));
}
auto extra_models = absl::GetFlag(FLAGS_extra_models);
auto f16_range_for_f32 = absl::GetFlag(FLAGS_f16_range_for_f32);
auto data_seed = absl::GetFlag(FLAGS_data_seed);
auto dispatch_dir = absl::GetFlag(FLAGS_dispatch_dir);
auto plugin_dir = absl::GetFlag(FLAGS_plugin_dir);
Expand All @@ -190,15 +219,19 @@ Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
auto limit = absl::GetFlag(FLAGS_limit);
auto soc_manufacturer = absl::GetFlag(FLAGS_soc_manufacturer);
auto soc_model = absl::GetFlag(FLAGS_soc_model);
LITERT_ASSIGN_OR_RETURN(auto target_options, ParseOptions(backend));
LITERT_ASSIGN_OR_RETURN(auto reference_options, Options::Create());
reference_options.SetHardwareAccelerators(kLiteRtHwAcceleratorCpu);
LITERT_ASSIGN_OR_RETURN(
auto plugin, ParsePlugin(plugin_dir, soc_manufacturer, compile_mode));
auto plugin,
ParsePlugin(plugin_dir, soc_manufacturer, compile_mode, target_options));
AtsConf res(std::move(seeds), backend, quiet, dispatch_dir, plugin_dir,
std::move(neg_re), std::move(pos_re), std::move(extra_models),
f16_range_for_f32, data_seed, iters_per_test,
std::move(max_ms_per_test_opt), fail_on_timeout, dump_report,
std::move(csv), compile_mode, std::move(models_out), limit,
std::move(plugin), std::move(soc_manufacturer),
std::move(soc_model));
data_seed, iters_per_test, std::move(max_ms_per_test_opt),
fail_on_timeout, dump_report, std::move(csv), compile_mode,
std::move(models_out), limit, std::move(plugin),
std::move(soc_manufacturer), std::move(soc_model),
std::move(target_options), std::move(reference_options));
Setup(res);
return res;
}
Expand All @@ -213,7 +246,15 @@ int AtsConf::GetSeedForParams(absl::string_view name) const {
}

bool AtsConf::ShouldRegister(const std::string& name) const {
return std::regex_search(name, pos_re_) && !std::regex_search(name, neg_re_);
const bool include =
pos_re_.empty() ||
std::any_of(pos_re_.begin(), pos_re_.end(), [&name](const auto& re) {
return std::regex_search(name, re);
});
const bool exclude = std::any_of(
neg_re_.begin(), neg_re_.end(),
[&name](const auto& re) { return std::regex_search(name, re); });
return include && !exclude;
};

bool AtsConf::ShouldRegister(absl::string_view name) const {
Expand Down
Loading
Loading