Skip to content

Commit d8e632c

Browse files
LukeBoyercopybara-github
authored andcommitted
Add compile data capture to ats.
* Pull out common capture logic. * Move "extra models" to generators LiteRT-PiperOrigin-RevId: 819975154
1 parent cccd495 commit d8e632c

File tree

16 files changed

+414
-148
lines changed

16 files changed

+414
-148
lines changed

litert/ats/BUILD

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,8 @@ cc_test(
4444
":register",
4545
"//litert/c:litert_logging",
4646
"//litert/c:litert_op_code",
47-
"//litert/cc:litert_buffer_ref",
4847
"//litert/cc:litert_c_types_printing",
4948
"//litert/cc:litert_detail",
50-
"//litert/compiler/plugin:compiler_plugin",
5149
"//litert/test/generators",
5250
"//litert/test/generators:common",
5351
"//tflite/schema:schema_fbs",
@@ -106,18 +104,11 @@ cc_library(
106104
deps = [
107105
":common",
108106
":configure",
109-
":inference_capture",
110-
"//litert/c:litert_common",
111107
"//litert/c:litert_logging",
112108
"//litert/cc:litert_detail",
113109
"//litert/cc:litert_expected",
114-
"//litert/cc:litert_macros",
115-
"//litert/cc:litert_model",
116110
"//litert/cc:litert_rng",
117-
"//litert/core/model:model_load",
118-
"//litert/test:simple_buffer",
119111
"//litert/test/generators",
120-
"@com_google_absl//absl/strings:string_view",
121112
],
122113
)
123114

@@ -158,6 +149,7 @@ cc_library(
158149
"//litert/test:simple_buffer",
159150
"//litert/test/generators:common",
160151
"@com_google_absl//absl/strings:str_format",
152+
"@com_google_absl//absl/strings:string_view",
161153
"@com_google_googletest//:gtest",
162154
],
163155
)
@@ -167,6 +159,7 @@ cc_library(
167159
testonly = True,
168160
hdrs = ["compile_fixture.h"],
169161
deps = [
162+
":capture_common",
170163
":common",
171164
":compile_capture",
172165
":configure",
@@ -177,6 +170,7 @@ cc_library(
177170
"//litert/core/model",
178171
"//litert/test:matchers",
179172
"//litert/test/generators:common",
173+
"@com_google_absl//absl/strings:string_view",
180174
"@com_google_googletest//:gtest",
181175
],
182176
)
@@ -217,6 +211,7 @@ cc_test(
217211
],
218212
deps = [
219213
":common",
214+
":compile_capture",
220215
":compile_fixture",
221216
":configure",
222217
":executor",
@@ -248,6 +243,7 @@ cc_library(
248243
testonly = True,
249244
hdrs = ["inference_capture.h"],
250245
deps = [
246+
":capture_common",
251247
":common",
252248
":print",
253249
"//litert/cc:litert_detail",
@@ -301,6 +297,8 @@ cc_library(
301297
testonly = True,
302298
hdrs = ["compile_capture.h"],
303299
deps = [
300+
":capture_common",
301+
":common",
304302
":print",
305303
"@com_google_absl//absl/strings:string_view",
306304
],
@@ -310,12 +308,25 @@ cc_test(
310308
name = "compile_capture_test",
311309
srcs = ["compile_capture_test.cc"],
312310
deps = [
311+
":common",
313312
":compile_capture",
314-
"@com_google_absl//absl/strings",
315313
"@com_google_googletest//:gtest_main",
316314
],
317315
)
318316

317+
cc_library(
318+
name = "capture_common",
319+
testonly = True,
320+
hdrs = ["capture_common.h"],
321+
deps = [
322+
":common",
323+
":configure",
324+
":print",
325+
"//litert/core/model",
326+
"@com_google_absl//absl/strings:string_view",
327+
],
328+
)
329+
319330
# PRE-CONFIGURED CTS SUITES ########################################################################
320331

321332
litert_define_ats(

litert/ats/ats.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,14 @@ int Ats() {
106106

107107
const auto res = RUN_ALL_TESTS();
108108

109-
options->Csv(i_cap);
110-
options->Print(i_cap);
111-
options->Csv(c_cap);
112-
options->Print(c_cap);
109+
// Final report.
110+
if (options->CompileMode()) {
111+
options->Csv(c_cap);
112+
options->Print(c_cap);
113+
} else {
114+
options->Csv(i_cap);
115+
options->Print(i_cap);
116+
}
113117

114118
return res;
115119
}

litert/ats/capture_common.h

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Copyright 2025 Google LLC.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef THIRD_PARTY_ODML_LITERT_LITERT_ATS_CAPTURE_COMMON_H_
16+
#define THIRD_PARTY_ODML_LITERT_LITERT_ATS_CAPTURE_COMMON_H_
17+
18+
#include <string>
19+
20+
#include "absl/strings/string_view.h" // from @com_google_absl
21+
#include "litert/ats/common.h"
22+
#include "litert/ats/configure.h"
23+
#include "litert/ats/print.h"
24+
#include "litert/core/model/model.h"
25+
26+
namespace litert::testing {
27+
28+
// Information about the input model.
29+
struct ModelDetail : public Printable<std::string, std::string, bool> {
30+
// File name, if in memeory only graph, an identifier of the graph.
31+
std::string name = "";
32+
// Optional description or representation of the model.
33+
std::string desc = "";
34+
// Was the input model precompiled offline?
35+
bool precompiled = false;
36+
37+
void SetFields(const TestNames& names, const LiteRtModelT& model) {
38+
name = names.report_id;
39+
desc = names.desc;
40+
precompiled = GetBuildStamp(model).has_value();
41+
}
42+
43+
ModelDetail() : Printable("ModelDetail", "name", "desc", "precompiled") {}
44+
45+
private:
46+
Fields GetFields() const override { return Fields{name, desc, precompiled}; }
47+
};
48+
49+
// Information about the accelerator used if any.
50+
struct AcceleratorDetail
51+
: public Printable<ExecutionBackend, std::string, std::string> {
52+
// The type of accelerator used.
53+
ExecutionBackend a_type = ExecutionBackend::kCpu;
54+
55+
// Only applicable in the NPU case.
56+
std::string soc_man = "n/a";
57+
std::string soc_model = "n/a";
58+
59+
void SetFields(const AtsConf& conf) {
60+
a_type = conf.Backend();
61+
if (conf.IsNpu()) {
62+
soc_man = conf.SocManufacturer();
63+
soc_model = conf.SocModel();
64+
}
65+
}
66+
67+
AcceleratorDetail()
68+
: Printable("AcceleratorDetail", "backend", "soc_man", "soc_model") {}
69+
70+
private:
71+
Fields GetFields() const override {
72+
return Fields{a_type, soc_man, soc_model};
73+
}
74+
};
75+
76+
// Information about any compilation that was done.
77+
struct CompilationDetail : public Printable<CompilationStatus> {
78+
// The status of the compilation.
79+
CompilationStatus status = CompilationStatus::kNotRequested;
80+
81+
CompilationDetail() : Printable("CompilationDetail", "status") {}
82+
83+
void SetFields(const AtsConf& conf, const LiteRtModelT& model, bool error) {
84+
if (!conf.IsNpu()) {
85+
return;
86+
}
87+
if (error) {
88+
status = CompilationStatus::kError;
89+
} else if (!internal::HasAnyCompiled(model)) {
90+
status = CompilationStatus::kNoOpsCompiled;
91+
} else if (!internal::IsFullyCompiled(model)) {
92+
status = CompilationStatus::kPartiallyCompiled;
93+
} else {
94+
status = CompilationStatus::kFullyCompiled;
95+
}
96+
}
97+
98+
private:
99+
Fields GetFields() const override { return Fields{status}; }
100+
};
101+
102+
} // namespace litert::testing
103+
104+
#endif // THIRD_PARTY_ODML_LITERT_LITERT_ATS_CAPTURE_COMMON_H_

litert/ats/check_ats.cc

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "absl/flags/reflection.h" // from @com_google_absl
2525
#include "absl/strings/string_view.h" // from @com_google_absl
2626
#include "litert/ats/common.h"
27+
#include "litert/ats/compile_capture.h"
2728
#include "litert/ats/compile_fixture.h"
2829
#include "litert/ats/configure.h"
2930
#include "litert/ats/executor.h"
@@ -114,15 +115,28 @@ Expected<void> CheckAts() {
114115
LITERT_ENSURE(!RUN_ALL_TESTS(), Error(kLiteRtStatusErrorRuntimeFailure),
115116
"Failed to run all tests.");
116117

117-
const auto cap_ok = std::all_of(i_cap.Rows().begin(), i_cap.Rows().end(),
118-
[](const InferenceCaptureEntry& row) {
119-
return row.run.status != RunStatus::kError;
120-
});
121-
LITERT_ENSURE(cap_ok && i_cap.Rows().size() == test_id - 1,
118+
const auto i_cap_ok =
119+
std::all_of(i_cap.Rows().begin(), i_cap.Rows().end(),
120+
[](const InferenceCaptureEntry& row) {
121+
return row.run.status != RunStatus::kError;
122+
});
123+
LITERT_ENSURE(i_cap_ok && i_cap.Rows().size() == test_id - 1,
122124
Error(kLiteRtStatusErrorRuntimeFailure),
123125
"Status capture contains errors.");
126+
127+
const auto c_cap_ok = std::all_of(c_cap.Rows().begin(), c_cap.Rows().end(),
128+
[](const CompileCaptureEntry& row) {
129+
return row.compilation_detail.status !=
130+
CompilationStatus::kError;
131+
});
132+
LITERT_ENSURE(c_cap_ok && c_cap.Rows().size() == 1,
133+
Error(kLiteRtStatusErrorRuntimeFailure),
134+
"Status capture contains errors.");
135+
124136
i_cap.Print(std::cerr);
125137
i_cap.Csv(std::cerr);
138+
c_cap.Print(std::cerr);
139+
c_cap.Csv(std::cerr);
126140

127141
// Check side effects.
128142

litert/ats/common.h

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ struct TestNames {
3939

4040
// Create using repr of ops as desc. Only use if the model has 1-ish ops.
4141
static TestNames Create(size_t test_id, absl::string_view family,
42-
const LiteRtModelT& graph) {
43-
auto suite = MakeSuite(test_id, family);
42+
absl::string_view logic, const LiteRtModelT& graph) {
43+
auto suite = MakeSuite(test_id, family, logic);
4444
auto test = absl::StrFormat("%v", graph.Subgraph(0).Ops());
4545
auto desc = test;
4646
auto report_id = suite;
@@ -49,14 +49,16 @@ struct TestNames {
4949

5050
// Create with an explicit desc.
5151
static TestNames Create(size_t test_id, absl::string_view family,
52-
absl::string_view test, absl::string_view desc = "") {
53-
auto suite = MakeSuite(test_id, family);
54-
return {suite, std::string(test), std::string(desc), std::string(test)};
52+
absl::string_view logic, absl::string_view test,
53+
absl::string_view desc = "") {
54+
auto suite = MakeSuite(test_id, family, logic);
55+
return {suite, std::string(logic), std::string(desc), std::string(test)};
5556
}
5657

5758
private:
58-
static std::string MakeSuite(size_t test_id, absl::string_view family) {
59-
return absl::StrFormat("ats_%lu_%s", test_id, family);
59+
static std::string MakeSuite(size_t test_id, absl::string_view family,
60+
absl::string_view logic) {
61+
return absl::StrFormat("ats_%lu_%s_%s", test_id, family, logic);
6062
}
6163
};
6264

@@ -80,6 +82,19 @@ enum class RunStatus {
8082
kTimeout,
8183
};
8284

85+
enum class CompilationStatus {
86+
// End never recorded.
87+
kNotRequested,
88+
// The compilation failed due to an error.
89+
kError,
90+
// Compilation succeeded, but no ops were compiled.
91+
kNoOpsCompiled,
92+
// Compilation succeeded, not all ops were compiled.
93+
kPartiallyCompiled,
94+
// Compilation succeeded, all ops were compiled.
95+
kFullyCompiled,
96+
};
97+
8398
// Timing related types.
8499
using Clock = std::chrono::steady_clock;
85100
using TimePoint = Clock::time_point;
@@ -138,6 +153,27 @@ void AbslStringify(Sink& sink, const RunStatus& status) {
138153
}
139154
}
140155

156+
template <typename Sink>
157+
void AbslStringify(Sink& sink, const CompilationStatus& status) {
158+
switch (status) {
159+
case CompilationStatus::kNotRequested:
160+
sink.Append("not_requested");
161+
break;
162+
case CompilationStatus::kError:
163+
sink.Append("error");
164+
break;
165+
case CompilationStatus::kNoOpsCompiled:
166+
sink.Append("no_ops_compiled");
167+
break;
168+
case CompilationStatus::kPartiallyCompiled:
169+
sink.Append("partially_compiled");
170+
break;
171+
case CompilationStatus::kFullyCompiled:
172+
sink.Append("fully_compiled");
173+
break;
174+
}
175+
}
176+
141177
template <typename Sink>
142178
void AbslStringify(Sink& sink, const Nanoseconds& ns) {
143179
absl::Format(&sink, "%e", ns);

0 commit comments

Comments
 (0)