Skip to content

Commit 718e0a3

Browse files
LukeBoyercopybara-github
authored andcommitted
Add cpu reference to ats.
LiteRT-PiperOrigin-RevId: 820045790
1 parent d8e632c commit 718e0a3

File tree

3 files changed

+87
-66
lines changed

3 files changed

+87
-66
lines changed

litert/ats/BUILD

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,9 @@ cc_test(
211211
],
212212
deps = [
213213
":common",
214-
":compile_capture",
215214
":compile_fixture",
216215
":configure",
217216
":executor",
218-
":inference_capture",
219217
":inference_fixture",
220218
":register",
221219
"//litert/c:litert_common",

litert/ats/check_ats.cc

Lines changed: 83 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@
2424
#include "absl/flags/reflection.h" // from @com_google_absl
2525
#include "absl/strings/string_view.h" // from @com_google_absl
2626
#include "litert/ats/common.h"
27-
#include "litert/ats/compile_capture.h"
2827
#include "litert/ats/compile_fixture.h"
2928
#include "litert/ats/configure.h"
3029
#include "litert/ats/executor.h"
31-
#include "litert/ats/inference_capture.h"
3230
#include "litert/ats/inference_fixture.h"
3331
#include "litert/ats/register.h"
3432
#include "litert/c/litert_common.h"
@@ -82,30 +80,42 @@ Expected<void> CheckAts() {
8280
absl::SetFlag(&FLAGS_models_out, dir.Str());
8381

8482
size_t test_id = 0;
83+
8584
typename AtsInferenceTest::Capture i_cap;
8685
typename AtsCompileTest::Capture c_cap;
8786

88-
// CPU
8987
LITERT_ASSIGN_OR_RETURN(auto cpu_inference_options, CpuInferenceOptions());
90-
RegisterCombinations<AtsInferenceTest, NoOp, SizeListC<1>,
91-
TypeList<float, int32_t>>(
92-
/*iters=*/1, test_id, cpu_inference_options, i_cap);
93-
RegisterCombinations<AtsInferenceTest, BinaryNoBroadcast, SizeListC<1>,
94-
TypeList<float>,
95-
OpCodeListC<kLiteRtOpCodeTflSub, kLiteRtOpCodeTflAdd>>(
96-
/*iters=*/1, test_id, cpu_inference_options, i_cap);
88+
LITERT_ASSIGN_OR_RETURN(auto compile_options, CompileOptions());
89+
LITERT_ASSIGN_OR_RETURN(auto npu_inference_options, NpuInferenceOptions());
90+
91+
// CPU
92+
{
93+
RegisterExtraModels<AtsInferenceTest>(test_id, cpu_inference_options,
94+
i_cap);
95+
RegisterCombinations<AtsInferenceTest, NoOp, SizeListC<1>,
96+
TypeList<float, int32_t>>(
97+
/*iters=*/1, test_id, cpu_inference_options, i_cap);
98+
RegisterCombinations<AtsInferenceTest, BinaryNoBroadcast, SizeListC<1>,
99+
TypeList<float>,
100+
OpCodeListC<kLiteRtOpCodeTflSub, kLiteRtOpCodeTflAdd>>(
101+
/*iters=*/1, test_id, cpu_inference_options, i_cap);
102+
}
97103

98104
// NPU
99-
LITERT_ASSIGN_OR_RETURN(auto npu_inference_options, NpuInferenceOptions());
100-
RegisterCombinations<AtsInferenceTest, BinaryNoBroadcast, SizeListC<1>,
101-
TypeList<float>, OpCodeListC<kLiteRtOpCodeTflSub>>(
102-
/*iters=*/1, test_id, npu_inference_options, i_cap);
105+
106+
{
107+
RegisterCombinations<AtsInferenceTest, BinaryNoBroadcast, SizeListC<1>,
108+
TypeList<float>, OpCodeListC<kLiteRtOpCodeTflSub>>(
109+
/*iters=*/1, test_id, npu_inference_options, i_cap);
110+
}
103111

104112
// Compile
105-
LITERT_ASSIGN_OR_RETURN(auto compile_options, CompileOptions());
106-
RegisterCombinations<AtsCompileTest, BinaryNoBroadcast, SizeListC<1>,
107-
TypeList<float>, OpCodeListC<kLiteRtOpCodeTflSub>>(
108-
/*iters=*/1, test_id, compile_options, c_cap);
113+
114+
{
115+
RegisterCombinations<AtsCompileTest, BinaryNoBroadcast, SizeListC<1>,
116+
TypeList<float>, OpCodeListC<kLiteRtOpCodeTflSub>>(
117+
/*iters=*/1, test_id, compile_options, c_cap);
118+
}
109119

110120
const auto* ut = ::testing::UnitTest::GetInstance();
111121
LITERT_ENSURE((ut->total_test_count() == test_id),
@@ -115,55 +125,67 @@ Expected<void> CheckAts() {
115125
LITERT_ENSURE(!RUN_ALL_TESTS(), Error(kLiteRtStatusErrorRuntimeFailure),
116126
"Failed to run all tests.");
117127

118-
const auto i_cap_ok =
119-
std::all_of(i_cap.Rows().begin(), i_cap.Rows().end(),
120-
[](const InferenceCaptureEntry& row) {
121-
return row.run.status != RunStatus::kError;
122-
});
123-
LITERT_ENSURE(i_cap_ok && i_cap.Rows().size() == test_id - 1,
124-
Error(kLiteRtStatusErrorRuntimeFailure),
125-
"Status capture contains errors.");
126-
127-
const auto c_cap_ok = std::all_of(c_cap.Rows().begin(), c_cap.Rows().end(),
128-
[](const CompileCaptureEntry& row) {
129-
return row.compilation_detail.status !=
130-
CompilationStatus::kError;
131-
});
132-
LITERT_ENSURE(c_cap_ok && c_cap.Rows().size() == 1,
133-
Error(kLiteRtStatusErrorRuntimeFailure),
134-
"Status capture contains errors.");
128+
// Check inference capture.
129+
{
130+
const auto num_extra_models = std::count_if(
131+
i_cap.Rows().begin(), i_cap.Rows().end(), [](const auto& row) {
132+
return row.numerics.reference_type == ReferenceType::kCpu;
133+
});
134+
135+
const auto i_cap_ok = std::all_of(
136+
i_cap.Rows().begin(), i_cap.Rows().end(),
137+
[](const auto& row) { return row.run.status != RunStatus::kError; });
138+
139+
LITERT_ENSURE(
140+
i_cap_ok && i_cap.Rows().size() == test_id - 1 && num_extra_models == 1,
141+
Error(kLiteRtStatusErrorRuntimeFailure),
142+
"Status capture contains errors.");
143+
}
144+
145+
// Check compile capture.
146+
{
147+
const auto c_cap_ok = std::all_of(
148+
c_cap.Rows().begin(), c_cap.Rows().end(), [](const auto& row) {
149+
return row.compilation_detail.status != CompilationStatus::kError;
150+
});
151+
152+
LITERT_ENSURE(c_cap_ok && c_cap.Rows().size() == 1,
153+
Error(kLiteRtStatusErrorRuntimeFailure),
154+
"Status capture contains errors.");
155+
}
135156

136157
i_cap.Print(std::cerr);
137158
i_cap.Csv(std::cerr);
138159
c_cap.Print(std::cerr);
139160
c_cap.Csv(std::cerr);
140161

141-
// Check side effects.
142-
143-
LITERT_ASSIGN_OR_RETURN(auto out_files, internal::ListDir(dir.Str()));
144-
LITERT_ENSURE(out_files.size() == 1, Error(kLiteRtStatusErrorRuntimeFailure),
145-
"Unexpected number of output files.");
146-
147-
const auto& out_file = out_files.front();
148-
LITERT_ENSURE(EndsWith(out_file, ".tflite"),
149-
Error(kLiteRtStatusErrorRuntimeFailure),
150-
"Unexpected output file name.");
151-
152-
// Check output file can be ran.
153-
154-
LITERT_ASSIGN_OR_RETURN(auto model, internal::LoadModelFromFile(out_file));
155-
LITERT_ENSURE(internal::IsFullyCompiled(*model),
156-
Error(kLiteRtStatusErrorRuntimeFailure),
157-
"Model is not fully compiled.")
158-
159-
LITERT_ASSIGN_OR_RETURN(auto exec,
160-
NpuCompiledModelExecutor::Create(
161-
*model, npu_inference_options.DispatchDir()));
162-
const auto& subgraph = *model->Subgraphs()[0];
163-
LITERT_ASSIGN_OR_RETURN(auto inputs,
164-
SimpleBuffer::LikeSignature(subgraph.Inputs().begin(),
165-
subgraph.Inputs().end()));
166-
LITERT_RETURN_IF_ERROR(exec.Run(inputs));
162+
// Check post-test saved models.
163+
{
164+
LITERT_ASSIGN_OR_RETURN(auto out_files, internal::ListDir(dir.Str()));
165+
LITERT_ENSURE(out_files.size() == 1,
166+
Error(kLiteRtStatusErrorRuntimeFailure),
167+
"Unexpected number of output files.");
168+
169+
const auto& out_file = out_files.front();
170+
LITERT_ENSURE(EndsWith(out_file, ".tflite"),
171+
Error(kLiteRtStatusErrorRuntimeFailure),
172+
"Unexpected output file name.");
173+
174+
// Check compiled file can be ran.
175+
LITERT_ASSIGN_OR_RETURN(auto model, internal::LoadModelFromFile(out_file));
176+
LITERT_ENSURE(internal::IsFullyCompiled(*model),
177+
Error(kLiteRtStatusErrorRuntimeFailure),
178+
"Model is not fully compiled.")
179+
180+
LITERT_ASSIGN_OR_RETURN(auto exec,
181+
NpuCompiledModelExecutor::Create(
182+
*model, npu_inference_options.DispatchDir()));
183+
const auto& subgraph = *model->Subgraphs()[0];
184+
LITERT_ASSIGN_OR_RETURN(
185+
auto inputs, SimpleBuffer::LikeSignature(subgraph.Inputs().begin(),
186+
subgraph.Inputs().end()));
187+
LITERT_RETURN_IF_ERROR(exec.Run(inputs));
188+
}
167189

168190
return {};
169191
}

litert/ats/inference_fixture.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,7 @@ class AtsInferenceTest : public RngTest {
140140

141141
Expected<VarBuffers> Actual(const VarBuffers& inputs,
142142
CompiledModelExecutor* exec) {
143-
LITERT_ASSIGN_OR_RETURN(auto actual, exec->Run(inputs, cap_.latency));
144-
return actual;
143+
return exec->Run(inputs, cap_.latency);
145144
}
146145

147146
Expected<VarBuffers> Reference(const VarBuffers& inputs) const {
@@ -156,7 +155,9 @@ class AtsInferenceTest : public RngTest {
156155
}
157156

158157
Expected<VarBuffers> CpuReference(const VarBuffers& inputs) const {
159-
return Error(kLiteRtStatusErrorInvalidArgument, "TODO");
158+
LITERT_ASSIGN_OR_RETURN(auto exec,
159+
CpuCompiledModelExecutor::Create(Graph()));
160+
return exec.Run(inputs);
160161
}
161162

162163
Expected<VarBuffers> MakeOutputs() const {

0 commit comments

Comments
 (0)