2424#include " absl/flags/reflection.h" // from @com_google_absl
2525#include " absl/strings/string_view.h" // from @com_google_absl
2626#include " litert/ats/common.h"
27- #include " litert/ats/compile_capture.h"
2827#include " litert/ats/compile_fixture.h"
2928#include " litert/ats/configure.h"
3029#include " litert/ats/executor.h"
31- #include " litert/ats/inference_capture.h"
3230#include " litert/ats/inference_fixture.h"
3331#include " litert/ats/register.h"
3432#include " litert/c/litert_common.h"
@@ -82,30 +80,42 @@ Expected<void> CheckAts() {
8280 absl::SetFlag (&FLAGS_models_out, dir.Str ());
8381
8482 size_t test_id = 0 ;
83+
8584 typename AtsInferenceTest::Capture i_cap;
8685 typename AtsCompileTest::Capture c_cap;
8786
88- // CPU
8987 LITERT_ASSIGN_OR_RETURN (auto cpu_inference_options, CpuInferenceOptions ());
90- RegisterCombinations<AtsInferenceTest, NoOp, SizeListC<1 >,
91- TypeList<float , int32_t >>(
92- /* iters=*/ 1 , test_id, cpu_inference_options, i_cap);
93- RegisterCombinations<AtsInferenceTest, BinaryNoBroadcast, SizeListC<1 >,
94- TypeList<float >,
95- OpCodeListC<kLiteRtOpCodeTflSub , kLiteRtOpCodeTflAdd >>(
96- /* iters=*/ 1 , test_id, cpu_inference_options, i_cap);
88+ LITERT_ASSIGN_OR_RETURN (auto compile_options, CompileOptions ());
89+ LITERT_ASSIGN_OR_RETURN (auto npu_inference_options, NpuInferenceOptions ());
90+
91+ // CPU
92+ {
93+ RegisterExtraModels<AtsInferenceTest>(test_id, cpu_inference_options,
94+ i_cap);
95+ RegisterCombinations<AtsInferenceTest, NoOp, SizeListC<1 >,
96+ TypeList<float , int32_t >>(
97+ /* iters=*/ 1 , test_id, cpu_inference_options, i_cap);
98+ RegisterCombinations<AtsInferenceTest, BinaryNoBroadcast, SizeListC<1 >,
99+ TypeList<float >,
100+ OpCodeListC<kLiteRtOpCodeTflSub , kLiteRtOpCodeTflAdd >>(
101+ /* iters=*/ 1 , test_id, cpu_inference_options, i_cap);
102+ }
97103
98104 // NPU
99- LITERT_ASSIGN_OR_RETURN (auto npu_inference_options, NpuInferenceOptions ());
100- RegisterCombinations<AtsInferenceTest, BinaryNoBroadcast, SizeListC<1 >,
101- TypeList<float >, OpCodeListC<kLiteRtOpCodeTflSub >>(
102- /* iters=*/ 1 , test_id, npu_inference_options, i_cap);
105+
106+ {
107+ RegisterCombinations<AtsInferenceTest, BinaryNoBroadcast, SizeListC<1 >,
108+ TypeList<float >, OpCodeListC<kLiteRtOpCodeTflSub >>(
109+ /* iters=*/ 1 , test_id, npu_inference_options, i_cap);
110+ }
103111
104112 // Compile
105- LITERT_ASSIGN_OR_RETURN (auto compile_options, CompileOptions ());
106- RegisterCombinations<AtsCompileTest, BinaryNoBroadcast, SizeListC<1 >,
107- TypeList<float >, OpCodeListC<kLiteRtOpCodeTflSub >>(
108- /* iters=*/ 1 , test_id, compile_options, c_cap);
113+
114+ {
115+ RegisterCombinations<AtsCompileTest, BinaryNoBroadcast, SizeListC<1 >,
116+ TypeList<float >, OpCodeListC<kLiteRtOpCodeTflSub >>(
117+ /* iters=*/ 1 , test_id, compile_options, c_cap);
118+ }
109119
110120 const auto * ut = ::testing::UnitTest::GetInstance ();
111121 LITERT_ENSURE ((ut->total_test_count () == test_id),
@@ -115,55 +125,67 @@ Expected<void> CheckAts() {
115125 LITERT_ENSURE (!RUN_ALL_TESTS (), Error (kLiteRtStatusErrorRuntimeFailure ),
116126 " Failed to run all tests." );
117127
118- const auto i_cap_ok =
119- std::all_of (i_cap.Rows ().begin (), i_cap.Rows ().end (),
120- [](const InferenceCaptureEntry& row) {
121- return row.run .status != RunStatus::kError ;
122- });
123- LITERT_ENSURE (i_cap_ok && i_cap.Rows ().size () == test_id - 1 ,
124- Error (kLiteRtStatusErrorRuntimeFailure ),
125- " Status capture contains errors." );
126-
127- const auto c_cap_ok = std::all_of (c_cap.Rows ().begin (), c_cap.Rows ().end (),
128- [](const CompileCaptureEntry& row) {
129- return row.compilation_detail .status !=
130- CompilationStatus::kError ;
131- });
132- LITERT_ENSURE (c_cap_ok && c_cap.Rows ().size () == 1 ,
133- Error (kLiteRtStatusErrorRuntimeFailure ),
134- " Status capture contains errors." );
128+ // Check inference capture.
129+ {
130+ const auto num_extra_models = std::count_if (
131+ i_cap.Rows ().begin (), i_cap.Rows ().end (), [](const auto & row) {
132+ return row.numerics .reference_type == ReferenceType::kCpu ;
133+ });
134+
135+ const auto i_cap_ok = std::all_of (
136+ i_cap.Rows ().begin (), i_cap.Rows ().end (),
137+ [](const auto & row) { return row.run .status != RunStatus::kError ; });
138+
139+ LITERT_ENSURE (
140+ i_cap_ok && i_cap.Rows ().size () == test_id - 1 && num_extra_models == 1 ,
141+ Error (kLiteRtStatusErrorRuntimeFailure ),
142+ " Status capture contains errors." );
143+ }
144+
145+ // Check compile capture.
146+ {
147+ const auto c_cap_ok = std::all_of (
148+ c_cap.Rows ().begin (), c_cap.Rows ().end (), [](const auto & row) {
149+ return row.compilation_detail .status != CompilationStatus::kError ;
150+ });
151+
152+ LITERT_ENSURE (c_cap_ok && c_cap.Rows ().size () == 1 ,
153+ Error (kLiteRtStatusErrorRuntimeFailure ),
154+ " Status capture contains errors." );
155+ }
135156
136157 i_cap.Print (std::cerr);
137158 i_cap.Csv (std::cerr);
138159 c_cap.Print (std::cerr);
139160 c_cap.Csv (std::cerr);
140161
141- // Check side effects.
142-
143- LITERT_ASSIGN_OR_RETURN (auto out_files, internal::ListDir (dir.Str ()));
144- LITERT_ENSURE (out_files.size () == 1 , Error (kLiteRtStatusErrorRuntimeFailure ),
145- " Unexpected number of output files." );
146-
147- const auto & out_file = out_files.front ();
148- LITERT_ENSURE (EndsWith (out_file, " .tflite" ),
149- Error (kLiteRtStatusErrorRuntimeFailure ),
150- " Unexpected output file name." );
151-
152- // Check output file can be ran.
153-
154- LITERT_ASSIGN_OR_RETURN (auto model, internal::LoadModelFromFile (out_file));
155- LITERT_ENSURE (internal::IsFullyCompiled (*model),
156- Error (kLiteRtStatusErrorRuntimeFailure ),
157- " Model is not fully compiled." )
158-
159- LITERT_ASSIGN_OR_RETURN (auto exec,
160- NpuCompiledModelExecutor::Create (
161- *model, npu_inference_options.DispatchDir ()));
162- const auto & subgraph = *model->Subgraphs ()[0 ];
163- LITERT_ASSIGN_OR_RETURN (auto inputs,
164- SimpleBuffer::LikeSignature (subgraph.Inputs ().begin (),
165- subgraph.Inputs ().end ()));
166- LITERT_RETURN_IF_ERROR (exec.Run (inputs));
162+ // Check post-test saved models.
163+ {
164+ LITERT_ASSIGN_OR_RETURN (auto out_files, internal::ListDir (dir.Str ()));
165+ LITERT_ENSURE (out_files.size () == 1 ,
166+ Error (kLiteRtStatusErrorRuntimeFailure ),
167+ " Unexpected number of output files." );
168+
169+ const auto & out_file = out_files.front ();
170+ LITERT_ENSURE (EndsWith (out_file, " .tflite" ),
171+ Error (kLiteRtStatusErrorRuntimeFailure ),
172+ " Unexpected output file name." );
173+
174+ // Check compiled file can be ran.
175+ LITERT_ASSIGN_OR_RETURN (auto model, internal::LoadModelFromFile (out_file));
176+ LITERT_ENSURE (internal::IsFullyCompiled (*model),
177+ Error (kLiteRtStatusErrorRuntimeFailure ),
178+ " Model is not fully compiled." )
179+
180+ LITERT_ASSIGN_OR_RETURN (auto exec,
181+ NpuCompiledModelExecutor::Create (
182+ *model, npu_inference_options.DispatchDir ()));
183+ const auto & subgraph = *model->Subgraphs ()[0 ];
184+ LITERT_ASSIGN_OR_RETURN (
185+ auto inputs, SimpleBuffer::LikeSignature (subgraph.Inputs ().begin (),
186+ subgraph.Inputs ().end ()));
187+ LITERT_RETURN_IF_ERROR (exec.Run (inputs));
188+ }
167189
168190 return {};
169191}
0 commit comments