Skip to content

Commit 4bed4be

Browse files
LukeBoyercopybara-github
authored andcommitted
Scaffold the compiler flow capture for ats.
* Make capture types members of fixtures for metaprogramming * Remove the "optional" capture thing for simplicity LiteRT-PiperOrigin-RevId: 819897176
1 parent ffdf8fe commit 4bed4be

File tree

11 files changed

+187
-95
lines changed

11 files changed

+187
-95
lines changed

litert/ats/BUILD

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ cc_test(
3838
# TODO: Copybara rewrite the gunit flags and remove this tag.
3939
tags = ["no_oss"],
4040
deps = [
41-
":capture",
4241
":compile_fixture",
4342
":configure",
4443
":inference_fixture",
@@ -86,7 +85,7 @@ cc_library(
8685
testonly = True,
8786
hdrs = ["executor.h"],
8887
deps = [
89-
":capture",
88+
":inference_capture",
9089
"//litert/c:litert_common",
9190
"//litert/cc:litert_compiled_model",
9291
"//litert/cc:litert_environment",
@@ -105,9 +104,9 @@ cc_library(
105104
testonly = True,
106105
hdrs = ["register.h"],
107106
deps = [
108-
":capture",
109107
":common",
110108
":configure",
109+
":inference_capture",
111110
"//litert/c:litert_common",
112111
"//litert/c:litert_logging",
113112
"//litert/cc:litert_detail",
@@ -143,10 +142,10 @@ cc_library(
143142
testonly = True,
144143
hdrs = ["inference_fixture.h"],
145144
deps = [
146-
":capture",
147145
":common",
148146
":configure",
149147
":executor",
148+
":inference_capture",
150149
"//litert/c:litert_common",
151150
"//litert/c:litert_logging",
152151
"//litert/cc:litert_c_types_printing",
@@ -168,8 +167,8 @@ cc_library(
168167
testonly = True,
169168
hdrs = ["compile_fixture.h"],
170169
deps = [
171-
":capture",
172170
":common",
171+
":compile_capture",
173172
":configure",
174173
"//litert/c:litert_logging",
175174
"//litert/cc:litert_c_types_printing",
@@ -217,11 +216,11 @@ cc_test(
217216
"notsan",
218217
],
219218
deps = [
220-
":capture",
221219
":common",
222220
":compile_fixture",
223221
":configure",
224222
":executor",
223+
":inference_capture",
225224
":inference_fixture",
226225
":register",
227226
"//litert/c:litert_common",
@@ -245,9 +244,9 @@ cc_test(
245244
)
246245

247246
cc_library(
248-
name = "capture",
247+
name = "inference_capture",
249248
testonly = True,
250-
hdrs = ["capture.h"],
249+
hdrs = ["inference_capture.h"],
251250
deps = [
252251
":common",
253252
":print",
@@ -257,10 +256,10 @@ cc_library(
257256
)
258257

259258
cc_test(
260-
name = "capture_test",
261-
srcs = ["capture_test.cc"],
259+
name = "inference_capture_test",
260+
srcs = ["inference_capture_test.cc"],
262261
deps = [
263-
":capture",
262+
":inference_capture",
264263
"@com_google_googletest//:gtest_main",
265264
],
266265
)
@@ -297,6 +296,26 @@ cc_test(
297296
],
298297
)
299298

299+
cc_library(
300+
name = "compile_capture",
301+
testonly = True,
302+
hdrs = ["compile_capture.h"],
303+
deps = [
304+
":print",
305+
"@com_google_absl//absl/strings:string_view",
306+
],
307+
)
308+
309+
cc_test(
310+
name = "compile_capture_test",
311+
srcs = ["compile_capture_test.cc"],
312+
deps = [
313+
":compile_capture",
314+
"@com_google_absl//absl/strings",
315+
"@com_google_googletest//:gtest_main",
316+
],
317+
)
318+
300319
# PRE-CONFIGURED CTS SUITES ########################################################################
301320

302321
litert_define_ats(

litert/ats/ats.cc

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515
#include <cstddef>
1616
#include <cstdint>
1717
#include <iostream>
18-
#include <optional>
1918

2019
#include <gtest/gtest.h>
2120
#include "absl/flags/parse.h" // from @com_google_absl
2221
#include "absl/log/absl_check.h" // from @com_google_absl
23-
#include "litert/ats/capture.h"
2422
#include "litert/ats/compile_fixture.h"
2523
#include "litert/ats/configure.h"
2624
#include "litert/ats/inference_fixture.h"
@@ -29,16 +27,13 @@
2927
#include "litert/c/litert_op_code.h"
3028
#include "litert/cc/litert_c_types_printing.h" // IWYU pragma: keep
3129
#include "litert/cc/litert_detail.h"
32-
#include "litert/compiler/plugin/compiler_plugin.h"
3330
#include "litert/test/generators/common.h"
3431
#include "litert/test/generators/generators.h"
3532
#include "tflite/schema/schema_generated.h"
3633

3734
namespace litert::testing {
3835
namespace {
3936

40-
using ::litert::internal::CompilerPlugin;
41-
4237
static constexpr const char* kArt = R"(
4338
### ###### ###### ######## ## ######## ######## ### ######## ####### ######## ######## ######## ###### ######## ###### ## ## #### ######## ########
4439
## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ##
@@ -51,7 +46,7 @@ static constexpr const char* kArt = R"(
5146

5247
template <typename Fixture>
5348
void RegisterNoOp(const AtsConf& options, size_t& test_id, size_t iters,
54-
AtsCapture::Ref cap) {
49+
typename Fixture::Capture& cap) {
5550
// clang-format off
5651
RegisterCombinations<
5752
Fixture,
@@ -64,7 +59,7 @@ void RegisterNoOp(const AtsConf& options, size_t& test_id, size_t iters,
6459

6560
template <typename Fixture>
6661
void RegisterBinaryNoBroadcast(const AtsConf& options, size_t& test_id,
67-
size_t iters, AtsCapture::Ref cap) {
62+
size_t iters, typename Fixture::Capture& cap) {
6863
// clang-format off
6964
RegisterCombinations<
7065
Fixture,
@@ -78,7 +73,8 @@ void RegisterBinaryNoBroadcast(const AtsConf& options, size_t& test_id,
7873
}
7974

8075
template <typename Fixture>
81-
void RegisterAll(const AtsConf& options, size_t& test_id, AtsCapture::Ref cap) {
76+
void RegisterAll(const AtsConf& options, size_t& test_id,
77+
typename Fixture::Capture& cap) {
8278
RegisterExtraModels<Fixture>(test_id, options, cap);
8379
RegisterNoOp<Fixture>(options, test_id, /*iters=*/10, cap);
8480
RegisterBinaryNoBroadcast<Fixture>(options, test_id, /*iters=*/10, cap);
@@ -91,15 +87,14 @@ int Ats() {
9187
ABSL_CHECK(options);
9288

9389
size_t test_id = 0;
94-
AtsCapture cap;
95-
96-
std::optional<CompilerPlugin> plugin = std::nullopt;
90+
typename AtsInferenceTest::Capture i_cap;
91+
typename AtsCompileTest::Capture c_cap;
9792

9893
if (!options->CompileMode()) {
9994
// TODO: lukeboyer - Add compile tests.
100-
RegisterAll<AtsInferenceTest>(*options, test_id, cap);
95+
RegisterAll<AtsInferenceTest>(*options, test_id, i_cap);
10196
} else {
102-
RegisterAll<AtsCompileTest>(*options, test_id, cap);
97+
RegisterAll<AtsCompileTest>(*options, test_id, c_cap);
10398
}
10499

105100
// Preliminary report.
@@ -111,8 +106,10 @@ int Ats() {
111106

112107
const auto res = RUN_ALL_TESTS();
113108

114-
options->Csv(cap);
115-
options->Print(cap);
109+
options->Csv(i_cap);
110+
options->Print(i_cap);
111+
options->Csv(c_cap);
112+
options->Print(c_cap);
116113

117114
return res;
118115
}

litert/ats/check_ats.cc

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
#include "absl/flags/parse.h" // from @com_google_absl
2424
#include "absl/flags/reflection.h" // from @com_google_absl
2525
#include "absl/strings/string_view.h" // from @com_google_absl
26-
#include "litert/ats/capture.h"
2726
#include "litert/ats/common.h"
2827
#include "litert/ats/compile_fixture.h"
2928
#include "litert/ats/configure.h"
3029
#include "litert/ats/executor.h"
30+
#include "litert/ats/inference_capture.h"
3131
#include "litert/ats/inference_fixture.h"
3232
#include "litert/ats/register.h"
3333
#include "litert/c/litert_common.h"
@@ -81,29 +81,30 @@ Expected<void> CheckAts() {
8181
absl::SetFlag(&FLAGS_models_out, dir.Str());
8282

8383
size_t test_id = 0;
84-
AtsCapture cap;
84+
typename AtsInferenceTest::Capture i_cap;
85+
typename AtsCompileTest::Capture c_cap;
8586

8687
// CPU
8788
LITERT_ASSIGN_OR_RETURN(auto cpu_inference_options, CpuInferenceOptions());
8889
RegisterCombinations<AtsInferenceTest, NoOp, SizeListC<1>,
8990
TypeList<float, int32_t>>(
90-
/*iters=*/1, test_id, cpu_inference_options, cap);
91+
/*iters=*/1, test_id, cpu_inference_options, i_cap);
9192
RegisterCombinations<AtsInferenceTest, BinaryNoBroadcast, SizeListC<1>,
9293
TypeList<float>,
9394
OpCodeListC<kLiteRtOpCodeTflSub, kLiteRtOpCodeTflAdd>>(
94-
/*iters=*/1, test_id, cpu_inference_options, cap);
95+
/*iters=*/1, test_id, cpu_inference_options, i_cap);
9596

9697
// NPU
9798
LITERT_ASSIGN_OR_RETURN(auto npu_inference_options, NpuInferenceOptions());
9899
RegisterCombinations<AtsInferenceTest, BinaryNoBroadcast, SizeListC<1>,
99100
TypeList<float>, OpCodeListC<kLiteRtOpCodeTflSub>>(
100-
/*iters=*/1, test_id, npu_inference_options, cap);
101+
/*iters=*/1, test_id, npu_inference_options, i_cap);
101102

102103
// Compile
103104
LITERT_ASSIGN_OR_RETURN(auto compile_options, CompileOptions());
104105
RegisterCombinations<AtsCompileTest, BinaryNoBroadcast, SizeListC<1>,
105106
TypeList<float>, OpCodeListC<kLiteRtOpCodeTflSub>>(
106-
/*iters=*/1, test_id, compile_options, cap);
107+
/*iters=*/1, test_id, compile_options, c_cap);
107108

108109
const auto* ut = ::testing::UnitTest::GetInstance();
109110
LITERT_ENSURE((ut->total_test_count() == test_id),
@@ -113,14 +114,15 @@ Expected<void> CheckAts() {
113114
LITERT_ENSURE(!RUN_ALL_TESTS(), Error(kLiteRtStatusErrorRuntimeFailure),
114115
"Failed to run all tests.");
115116

116-
const auto cap_ok = std::all_of(cap.Rows().begin(), cap.Rows().end(),
117-
[](const AtsCaptureEntry& row) {
117+
const auto cap_ok = std::all_of(i_cap.Rows().begin(), i_cap.Rows().end(),
118+
[](const InferenceCaptureEntry& row) {
118119
return row.run.status != RunStatus::kError;
119120
});
120-
LITERT_ENSURE(cap_ok && cap.Rows().size() == test_id,
121+
LITERT_ENSURE(cap_ok && i_cap.Rows().size() == test_id - 1,
121122
Error(kLiteRtStatusErrorRuntimeFailure),
122123
"Status capture contains errors.");
123-
cap.Print(std::cerr);
124+
i_cap.Print(std::cerr);
125+
i_cap.Csv(std::cerr);
124126

125127
// Check side effects.
126128

litert/ats/compile_capture.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright 2025 Google LLC.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef THIRD_PARTY_ODML_LITERT_LITERT_ATS_COMPILE_CAPTURE_H_
16+
#define THIRD_PARTY_ODML_LITERT_LITERT_ATS_COMPILE_CAPTURE_H_
17+
18+
#include <string>
19+
20+
#include "absl/strings/string_view.h" // from @com_google_absl
21+
#include "litert/ats/print.h"
22+
23+
namespace litert::testing {
24+
25+
struct CompileCaptureEntry : public PrintableRow<> {
26+
CompileCaptureEntry() = default;
27+
28+
private:
29+
Printables GetPrintables() const override { return Printables{}; }
30+
31+
std::string Name() const override { return "CompileCapture"; }
32+
};
33+
34+
// TODO: lukeboyer - Implement this and subclasses.
35+
class CompileCapture : public PrintableCollection<CompileCaptureEntry> {
36+
public:
37+
using Entry = CompileCaptureEntry;
38+
39+
private:
40+
absl::string_view Name() const override { return "Ats Compile Results"; }
41+
};
42+
43+
} // namespace litert::testing
44+
45+
#endif // THIRD_PARTY_ODML_LITERT_LITERT_ATS_COMPILE_CAPTURE_H_

litert/ats/compile_capture_test.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright 2025 Google LLC.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "litert/ats/compile_capture.h"
16+
17+
#include <sstream>
18+
19+
#include <gmock/gmock.h>
20+
#include <gtest/gtest.h>
21+
22+
namespace litert::testing {
23+
namespace {
24+
25+
using ::testing::HasSubstr;
26+
27+
TEST(AtsCompileCaptureTest, Basic) {
28+
CompileCapture cap;
29+
cap.NewEntry();
30+
31+
std::ostringstream s;
32+
cap.Print(s);
33+
34+
EXPECT_THAT(s.str(), HasSubstr("CompileCapture"));
35+
}
36+
37+
} // namespace
38+
} // namespace litert::testing

0 commit comments

Comments
 (0)