Skip to content

Commit 555a7fe

Browse files
LukeBoyercopybara-github
authored andcommitted
hook up aot compile flow for ats
LiteRT-PiperOrigin-RevId: 819871594
1 parent 4a19e10 commit 555a7fe

File tree

9 files changed

+449
-134
lines changed

9 files changed

+449
-134
lines changed

litert/ats/BUILD

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,16 @@ cc_test(
3939
tags = ["no_oss"],
4040
deps = [
4141
":capture",
42+
":compile_fixture",
4243
":configure",
4344
":inference_fixture",
4445
":register",
4546
"//litert/c:litert_logging",
4647
"//litert/c:litert_op_code",
48+
"//litert/cc:litert_buffer_ref",
4749
"//litert/cc:litert_c_types_printing",
4850
"//litert/cc:litert_detail",
51+
"//litert/compiler/plugin:compiler_plugin",
4952
"//litert/test/generators",
5053
"//litert/test/generators:common",
5154
"//tflite/schema:schema_fbs",
@@ -67,7 +70,9 @@ cc_library(
6770
"//litert/cc:litert_expected",
6871
"//litert/cc:litert_macros",
6972
"//litert/cc:litert_rng",
73+
"//litert/compiler/plugin:compiler_plugin",
7074
"//litert/core:filesystem_testonly",
75+
"//litert/core/model:model_serialize",
7176
"@com_google_absl//absl/container:flat_hash_map",
7277
"@com_google_absl//absl/flags:flag",
7378
"@com_google_absl//absl/strings",
@@ -101,6 +106,7 @@ cc_library(
101106
hdrs = ["register.h"],
102107
deps = [
103108
":capture",
109+
":common",
104110
":configure",
105111
"//litert/c:litert_common",
106112
"//litert/c:litert_logging",
@@ -152,9 +158,26 @@ cc_library(
152158
"//litert/test:rng_fixture",
153159
"//litert/test:simple_buffer",
154160
"//litert/test/generators:common",
155-
"@com_google_absl//absl/strings",
156161
"@com_google_absl//absl/strings:str_format",
157-
"@com_google_absl//absl/strings:string_view",
162+
"@com_google_googletest//:gtest",
163+
],
164+
)
165+
166+
cc_library(
167+
name = "compile_fixture",
168+
testonly = True,
169+
hdrs = ["compile_fixture.h"],
170+
deps = [
171+
":capture",
172+
":common",
173+
":configure",
174+
"//litert/c:litert_logging",
175+
"//litert/cc:litert_c_types_printing",
176+
"//litert/compiler/plugin:compiler_plugin",
177+
"//litert/core:filesystem_testonly",
178+
"//litert/core/model",
179+
"//litert/test:matchers",
180+
"//litert/test/generators:common",
158181
"@com_google_googletest//:gtest",
159182
],
160183
)
@@ -196,19 +219,26 @@ cc_test(
196219
deps = [
197220
":capture",
198221
":common",
222+
":compile_fixture",
199223
":configure",
224+
":executor",
200225
":inference_fixture",
201226
":register",
227+
"//litert/c:litert_common",
202228
"//litert/c:litert_op_code",
203229
"//litert/cc:litert_detail",
204230
"//litert/cc:litert_expected",
231+
"//litert/cc:litert_macros",
232+
"//litert/core:filesystem",
233+
"//litert/core/model",
234+
"//litert/core/model:model_load",
205235
"//litert/test:common",
236+
"//litert/test:simple_buffer",
206237
"//litert/test/generators",
207238
"//litert/test/generators:common",
208239
"@com_google_absl//absl/flags:flag",
209240
"@com_google_absl//absl/flags:parse",
210-
"@com_google_absl//absl/log:absl_check",
211-
"@com_google_absl//absl/strings",
241+
"@com_google_absl//absl/flags:reflection",
212242
"@com_google_absl//absl/strings:string_view",
213243
"@com_google_googletest//:gtest",
214244
],
@@ -239,7 +269,11 @@ cc_library(
239269
name = "common",
240270
testonly = True,
241271
hdrs = ["common.h"],
242-
deps = ["@com_google_absl//absl/strings:str_format"],
272+
deps = [
273+
"//litert/core/model",
274+
"@com_google_absl//absl/strings:str_format",
275+
"@com_google_absl//absl/strings:string_view",
276+
],
243277
)
244278

245279
cc_library(

litert/ats/ats.cc

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,30 @@
1515
#include <cstddef>
1616
#include <cstdint>
1717
#include <iostream>
18+
#include <optional>
1819

1920
#include <gtest/gtest.h>
2021
#include "absl/flags/parse.h" // from @com_google_absl
2122
#include "absl/log/absl_check.h" // from @com_google_absl
2223
#include "litert/ats/capture.h"
24+
#include "litert/ats/compile_fixture.h"
2325
#include "litert/ats/configure.h"
2426
#include "litert/ats/inference_fixture.h"
2527
#include "litert/ats/register.h"
2628
#include "litert/c/litert_logging.h"
2729
#include "litert/c/litert_op_code.h"
2830
#include "litert/cc/litert_c_types_printing.h" // IWYU pragma: keep
2931
#include "litert/cc/litert_detail.h"
32+
#include "litert/compiler/plugin/compiler_plugin.h"
3033
#include "litert/test/generators/common.h"
3134
#include "litert/test/generators/generators.h"
3235
#include "tflite/schema/schema_generated.h"
3336

3437
namespace litert::testing {
3538
namespace {
3639

40+
using ::litert::internal::CompilerPlugin;
41+
3742
static constexpr const char* kArt = R"(
3843
### ###### ###### ######## ## ######## ######## ### ######## ####### ######## ######## ######## ###### ######## ###### ## ## #### ######## ########
3944
## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ##
@@ -44,23 +49,25 @@ static constexpr const char* kArt = R"(
4449
## ## ###### ###### ######## ######## ######## ## ## ## ## ## ####### ## ## ## ######## ###### ## ###### ####### #### ## ########
4550
)";
4651

52+
template <typename Fixture>
4753
void RegisterNoOp(const AtsConf& options, size_t& test_id, size_t iters,
4854
AtsCapture::Ref cap) {
4955
// clang-format off
5056
RegisterCombinations<
51-
AtsInferenceTest,
57+
Fixture,
5258
NoOp,
5359
SizeListC<1, 2, 3, 4>,
5460
TypeList<float, int32_t>>
5561
(iters, test_id, options, cap);
5662
// clang-format on
5763
}
5864

65+
template <typename Fixture>
5966
void RegisterBinaryNoBroadcast(const AtsConf& options, size_t& test_id,
6067
size_t iters, AtsCapture::Ref cap) {
6168
// clang-format off
6269
RegisterCombinations<
63-
AtsInferenceTest,
70+
Fixture,
6471
BinaryNoBroadcast,
6572
SizeListC<1, 2, 3, 4, 5, 6>,
6673
TypeList<float, int32_t>,
@@ -70,6 +77,13 @@ void RegisterBinaryNoBroadcast(const AtsConf& options, size_t& test_id,
7077
// clang-format on
7178
}
7279

80+
template <typename Fixture>
81+
void RegisterAll(const AtsConf& options, size_t& test_id, AtsCapture::Ref cap) {
82+
RegisterExtraModels<Fixture>(test_id, options, cap);
83+
RegisterNoOp<Fixture>(options, test_id, /*iters=*/10, cap);
84+
RegisterBinaryNoBroadcast<Fixture>(options, test_id, /*iters=*/10, cap);
85+
}
86+
7387
int Ats() {
7488
std::cerr << kArt << std::endl;
7589

@@ -79,9 +93,14 @@ int Ats() {
7993
size_t test_id = 0;
8094
AtsCapture cap;
8195

82-
RegisterNoOp(*options, test_id, /*iters=*/10, cap);
83-
RegisterBinaryNoBroadcast(*options, test_id, /*iters=*/10, cap);
84-
RegisterExtraModels<AtsInferenceTest>(test_id, *options, cap);
96+
std::optional<CompilerPlugin> plugin = std::nullopt;
97+
98+
if (!options->CompileMode()) {
99+
// TODO: lukeboyer - Add compile tests.
100+
RegisterAll<AtsInferenceTest>(*options, test_id, cap);
101+
} else {
102+
RegisterAll<AtsCompileTest>(*options, test_id, cap);
103+
}
85104

86105
// Preliminary report.
87106
{

0 commit comments

Comments
 (0)