Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions litert/ats/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,25 @@ cc_library(
testonly = True,
srcs = ["configure.cc"],
hdrs = ["configure.h"],
copts = commandline_flag_copts(),
copts = commandline_flag_copts() + [
"-DINCLUDE_QUALCOMM_COMPILE_FLAGS",
"-DINCLUDE_QUALCOMM_RUNTIME_FLAGS",
"-DINCLUDE_MEDIATEK_COMPILE_FLAGS",
"-DINCLUDE_MEDIATEK_RUNTIME_FLAGS",
],
deps = [
":common",
"//litert/c:litert_common",
"//litert/c/internal:litert_logging",
"//litert/cc:litert_expected",
"//litert/cc:litert_macros",
"//litert/cc:litert_options",
"//litert/cc/internal:litert_rng",
"//litert/compiler/plugin:compiler_plugin",
"//litert/core:filesystem_testonly",
"//litert/core/model:model_serialize",
"//litert/tools/flags/vendors:mediatek_flags",
"//litert/tools/flags/vendors:qualcomm_flags",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -111,6 +119,7 @@ cc_library(
"//litert/cc:litert_expected",
"//litert/cc/internal:litert_detail",
"//litert/cc/internal:litert_rng",
"//litert/core:filesystem",
"//litert/test/generators",
],
)
Expand All @@ -120,8 +129,10 @@ cc_test(
srcs = ["executor_test.cc"],
deps = [
":executor",
"//litert/c:litert_common",
"//litert/c:litert_op_code",
"//litert/cc:litert_buffer_ref",
"//litert/cc:litert_options",
"//litert/core/model",
"//litert/test:matchers",
"//litert/test:simple_buffer",
Expand Down Expand Up @@ -186,8 +197,10 @@ litert_device_test(
defines = ["_TEST_NPU"],
deps = [
":executor",
"//litert/c:litert_common",
"//litert/c:litert_op_code",
"//litert/cc:litert_buffer_ref",
"//litert/cc:litert_options",
"//litert/core/model",
"//litert/test:matchers",
"//litert/test:simple_buffer",
Expand Down Expand Up @@ -367,7 +380,6 @@ litert_define_ats(
# "^(?:(?!npu_ats_87).)*$$",
],
extra_flags = [
"--f16_range_for_f32=true",
"--data_seed=42",
],
jit_suffix = "",
Expand Down
3 changes: 2 additions & 1 deletion litert/ats/check_ats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ Expected<void> CheckAts() {

LITERT_ASSIGN_OR_RETURN(auto exec,
NpuCompiledModelExecutor::Create(
*model, npu_inference_options.DispatchDir()));
*model, npu_inference_options.TargetOptions(),
npu_inference_options.DispatchDir()));
const auto& subgraph = *model->Subgraphs()[0];
LITERT_ASSIGN_OR_RETURN(
auto inputs, SimpleBuffer::LikeSignature(subgraph.Inputs().begin(),
Expand Down
14 changes: 8 additions & 6 deletions litert/ats/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ struct TestNames {
}

// Create with an explicit desc.
static TestNames Create(size_t test_id, absl::string_view family,
absl::string_view logic, absl::string_view test,
static TestNames Create(size_t test_id, absl::string_view fixture,
absl::string_view source, absl::string_view test,
absl::string_view report_id,
absl::string_view desc = "") {
auto suite = MakeSuite(test_id, family, logic);
return {suite, std::string(logic), std::string(desc), std::string(test)};
auto suite = MakeSuite(test_id, fixture, source);
return {suite, std::string(test), std::string(desc),
std::string(report_id)};
}

private:
Expand Down Expand Up @@ -98,7 +100,7 @@ enum class CompilationStatus {
// Timing related types.
using Clock = std::chrono::steady_clock;
using TimePoint = Clock::time_point;
using Nanoseconds = uint64_t;
using Microseconds = uint64_t;

// Which backend to use as the "actual".
enum class ExecutionBackend { kCpu, kGpu, kNpu };
Expand Down Expand Up @@ -175,7 +177,7 @@ void AbslStringify(Sink& sink, const CompilationStatus& status) {
}

template <typename Sink>
void AbslStringify(Sink& sink, const Nanoseconds& ns) {
void AbslStringify(Sink& sink, const Microseconds& ns) {
absl::Format(&sink, "%e", ns);
}

Expand Down
8 changes: 4 additions & 4 deletions litert/ats/compile_capture.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,25 @@
namespace litert::testing {

// Information about the time taken to compile the model.
class CompilationTime : public Printable<Nanoseconds> {
class CompilationTime : public Printable<Microseconds> {
public:
// Start timing.
TimePoint Start() const { return Clock::now(); }

// Stop timing and record the latency.
void Stop(const TimePoint& start) {
std::chrono::duration<Nanoseconds, std::nano> nano = Clock::now() - start;
std::chrono::duration<Microseconds, std::nano> nano = Clock::now() - start;
nanos_ = nano.count();
}

Nanoseconds Nanos() const { return nanos_; }
Microseconds Nanos() const { return nanos_; }

CompilationTime() : Printable("CompilationTime", "compile_time(ns)") {}

private:
Fields GetFields() const override { return Fields{nanos_}; }

Nanoseconds nanos_ = std::numeric_limits<Nanoseconds>::max();
Microseconds nanos_ = std::numeric_limits<Microseconds>::max();
};

// Type to hold all of the capturable information related compilation test
Expand Down
2 changes: 1 addition & 1 deletion litert/ats/compile_capture_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ TEST(AtsCompileCaptureTest, Basic) {
e.model.name = "FOO";

ASSERT_NE(e.compilation_time.Nanos(),
std::numeric_limits<Nanoseconds>::max());
std::numeric_limits<Microseconds>::max());

std::ostringstream s;
cap.Print(s);
Expand Down
46 changes: 37 additions & 9 deletions litert/ats/configure.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
#include "litert/c/litert_common.h"
#include "litert/cc/litert_expected.h"
#include "litert/cc/litert_macros.h"
#include "litert/cc/litert_options.h"
#include "litert/compiler/plugin/compiler_plugin.h"
#include "litert/tools/flags/vendors/mediatek_flags.h" // IWYU pragma: export
#include "litert/tools/flags/vendors/qualcomm_flags.h" // IWYU pragma: export

ABSL_FLAG(std::optional<int>, data_seed, std::nullopt,
"Seed for the buffer data generation.");
Expand Down Expand Up @@ -113,6 +116,9 @@ namespace litert::testing {

namespace {

using ::litert::mediatek::MediatekOptionsFromFlags;
using ::litert::qualcomm::QualcommOptionsFromFlags;

Expected<AtsConf::SeedMap> ParseParamSeedMap() {
const auto seed_flags = absl::GetFlag(FLAGS_seeds);
AtsConf::SeedMap seeds;
Expand Down Expand Up @@ -143,15 +149,34 @@ Expected<ExecutionBackend> ParseBackend() {
}
}

Expected<Options> ParseOptions(ExecutionBackend backend) {
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
if (backend == ExecutionBackend::kNpu) {
if (auto qnn_opts = QualcommOptionsFromFlags()) {
options.AddOpaqueOptions(std::move(*qnn_opts));
}
if (auto mediatek_opts = MediatekOptionsFromFlags()) {
options.AddOpaqueOptions(std::move(*mediatek_opts));
}
options.SetHardwareAccelerators(kLiteRtHwAcceleratorNpu);
} else if (backend == ExecutionBackend::kCpu) {
options.SetHardwareAccelerators(kLiteRtHwAcceleratorCpu);
} else if (backend == ExecutionBackend::kGpu) {
options.SetHardwareAccelerators(kLiteRtHwAcceleratorGpu);
}
return options;
}

Expected<std::optional<internal::CompilerPlugin>> ParsePlugin(
absl::string_view plugin_dir, absl::string_view soc_manufacturer,
bool compile_mode) {
bool compile_mode, const Options& litert_options) {
using R = std::optional<internal::CompilerPlugin>;
if (!compile_mode) {
return R(std::nullopt);
}
LITERT_ASSIGN_OR_RETURN(auto plugin, internal::CompilerPlugin::FindPlugin(
soc_manufacturer, {plugin_dir}));
soc_manufacturer, {plugin_dir},
nullptr, litert_options.Get()));
return R(std::move(plugin));
}

Expand All @@ -171,7 +196,6 @@ Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
std::regex pos_re(absl::GetFlag(FLAGS_do_register),
std::regex_constants::ECMAScript);
auto extra_models = absl::GetFlag(FLAGS_extra_models);
auto f16_range_for_f32 = absl::GetFlag(FLAGS_f16_range_for_f32);
auto data_seed = absl::GetFlag(FLAGS_data_seed);
auto dispatch_dir = absl::GetFlag(FLAGS_dispatch_dir);
auto plugin_dir = absl::GetFlag(FLAGS_plugin_dir);
Expand All @@ -190,15 +214,19 @@ Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
auto limit = absl::GetFlag(FLAGS_limit);
auto soc_manufacturer = absl::GetFlag(FLAGS_soc_manufacturer);
auto soc_model = absl::GetFlag(FLAGS_soc_model);
LITERT_ASSIGN_OR_RETURN(auto target_options, ParseOptions(backend));
LITERT_ASSIGN_OR_RETURN(auto reference_options, Options::Create());
reference_options.SetHardwareAccelerators(kLiteRtHwAcceleratorCpu);
LITERT_ASSIGN_OR_RETURN(
auto plugin, ParsePlugin(plugin_dir, soc_manufacturer, compile_mode));
auto plugin,
ParsePlugin(plugin_dir, soc_manufacturer, compile_mode, target_options));
AtsConf res(std::move(seeds), backend, quiet, dispatch_dir, plugin_dir,
std::move(neg_re), std::move(pos_re), std::move(extra_models),
f16_range_for_f32, data_seed, iters_per_test,
std::move(max_ms_per_test_opt), fail_on_timeout, dump_report,
std::move(csv), compile_mode, std::move(models_out), limit,
std::move(plugin), std::move(soc_manufacturer),
std::move(soc_model));
data_seed, iters_per_test, std::move(max_ms_per_test_opt),
fail_on_timeout, dump_report, std::move(csv), compile_mode,
std::move(models_out), limit, std::move(plugin),
std::move(soc_manufacturer), std::move(soc_model),
std::move(target_options), std::move(reference_options));
Setup(res);
return res;
}
Expand Down
33 changes: 21 additions & 12 deletions litert/ats/configure.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
#include "litert/ats/common.h"
#include "litert/cc/internal/litert_rng.h"
#include "litert/cc/litert_expected.h"
#include "litert/cc/litert_options.h"
#include "litert/compiler/plugin/compiler_plugin.h"
#include "litert/core/filesystem.h"
#include "litert/core/model/model_serialize.h"
#include "litert/tools/flags/vendors/mediatek_flags.h" // IWYU pragma: export
#include "litert/tools/flags/vendors/qualcomm_flags.h" // IWYU pragma: export

// Seed for the data generation.
ABSL_DECLARE_FLAG(std::optional<int>, data_seed);
Expand Down Expand Up @@ -59,9 +62,6 @@ ABSL_DECLARE_FLAG(std::string, dont_register);
// Regex for explicit inclusions.
ABSL_DECLARE_FLAG(std::string, do_register);

// Will generate values for f32 tensors in the range of f16 values.
ABSL_DECLARE_FLAG(bool, f16_range_for_f32);

// Optional list of directories, or model files to add to the test.
ABSL_DECLARE_FLAG(std::vector<std::string>, extra_models);

Expand Down Expand Up @@ -232,6 +232,12 @@ class AtsConf {
// compilation.
const std::string& SocModel() const { return soc_model_; }

// Litert options to use for the target backend.
const Options& TargetOptions() const { return target_options_; }

// Litert options to use for the reference backend.
const Options& ReferenceOptions() const { return reference_options_; }

AtsConf(const AtsConf&) = delete;
AtsConf& operator=(const AtsConf&) = delete;
AtsConf(AtsConf&&) = default;
Expand All @@ -242,13 +248,13 @@ class AtsConf {
bool quiet, std::string dispatch_dir, std::string plugin_dir,
std::regex&& neg_re, std::regex&& pos_re,
std::vector<std::string> extra_models,
bool f16_range_for_f32, std::optional<int> data_seed,
size_t iters_per_test,
std::optional<int> data_seed, size_t iters_per_test,
std::chrono::milliseconds max_ms_per_test,
bool fail_on_timeout, bool dump_report, std::string csv,
bool compile_mode, std::string models_out, int32_t limit,
std::optional<internal::CompilerPlugin> plugin,
std::string soc_manufacturer, std::string soc_model)
std::string soc_manufacturer, std::string soc_model,
Options&& target_options, Options&& reference_options)
: seeds_for_params_(std::move(seeds_for_params)),
backend_(backend),
quiet_(quiet),
Expand All @@ -257,7 +263,7 @@ class AtsConf {
neg_re_(std::move(neg_re)),
pos_re_(std::move(pos_re)),
extra_models_(std::move(extra_models)),
f16_range_for_f32_(f16_range_for_f32),

data_seed_(data_seed),
iters_per_test_(iters_per_test),
max_ms_per_test_(std::move(max_ms_per_test)),
Expand All @@ -269,10 +275,12 @@ class AtsConf {
limit_(limit),
plugin_(std::move(plugin)),
soc_manufacturer_(std::move(soc_manufacturer)),
soc_model_(std::move(soc_model)) {
if (f16_range_for_f32_) {
data_builder_.SetF16InF32();
}
soc_model_(std::move(soc_model)),
target_options_(std::move(target_options)),
reference_options_(std::move(reference_options)) {
// For now, we will provide default settings for data generation.
// More configurability may be introduced later.
data_builder_.SetSin();
}

SeedMap seeds_for_params_;
Expand All @@ -284,7 +292,6 @@ class AtsConf {
std::regex neg_re_;
std::regex pos_re_;
std::vector<std::string> extra_models_;
bool f16_range_for_f32_;
std::optional<int> data_seed_;
size_t iters_per_test_;
std::chrono::milliseconds max_ms_per_test_;
Expand All @@ -297,6 +304,8 @@ class AtsConf {
std::optional<internal::CompilerPlugin> plugin_;
std::string soc_manufacturer_;
std::string soc_model_;
Options target_options_;
Options reference_options_;

RandomTensorDataBuilder data_builder_;
};
Expand Down
Loading
Loading