12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
15
- #include < gflags/gflags.h>
16
- #include < glog/logging.h>
17
- #include < chrono>
18
- #include < fstream>
19
- #include < numeric>
20
- #include < sstream>
21
- #include < string>
22
- #include < vector>
23
- #include " paddle/fluid/inference/api/paddle_inference_api.h"
24
-
25
- DEFINE_int32 (repeat, 1 , " repeat" );
15
+ #include " paddle/fluid/inference/tests/api/tester_helper.h"
26
16
27
17
namespace paddle {
28
18
namespace inference {
@@ -166,16 +156,17 @@ bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) {
166
156
167
157
std::ifstream fin (FLAGS_infer_data);
168
158
std::string line;
159
+ int sample = 0 ;
169
160
170
- int lineno = 0 ;
161
+ // The unit-test dataset only has 10 samples; each sample has 5 feeds.
171
162
while (std::getline (fin, line)) {
172
163
std::vector<paddle::PaddleTensor> feed_data;
173
- if (!ParseLine (line, &feed_data)) {
174
- LOG (ERROR) << " Parse line[" << lineno << " ] error!" ;
175
- } else {
176
- inputs->push_back (std::move (feed_data));
177
- }
164
+ ParseLine (line, &feed_data);
165
+ inputs->push_back (std::move (feed_data));
166
+ sample++;
167
+ if (!FLAGS_test_all_data && sample == FLAGS_batch_size) break ;
178
168
}
169
+ LOG (INFO) << " number of samples: " << sample;
179
170
180
171
return true ;
181
172
}
@@ -199,19 +190,53 @@ void profile(bool use_mkldnn = false) {
199
190
inputs, &outputs, FLAGS_num_threads);
200
191
}
201
192
193
// Profile the BERT model through the analysis predictor (timing is handled
// inside profile(); see its definition above).
TEST(Analyzer_bert, profile) { profile(); }
#ifdef PADDLE_WITH_MKLDNN
// Same profiling run with the MKL-DNN backend enabled.
TEST(Analyzer_bert, profile_mkldnn) { profile(true); }
#endif
198
// Check the fuse status: build an analysis predictor for the configured BERT
// model and log the operator count after the analysis/fusion passes ran.
TEST(Analyzer_bert, fuse_statis) {
  AnalysisConfig cfg;
  SetConfig(&cfg);
  int num_ops;
  auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
  // GetFuseStatis returns the per-pass fusion statistics and fills num_ops
  // with the number of ops left in the optimized program.
  // NOTE(review): fuse_statis itself is computed but not asserted on here —
  // presumably this test only sanity-checks that analysis completes; confirm.
  auto fuse_statis = GetFuseStatis(
      static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
  LOG(INFO) << "num_ops: " << num_ops;
}
209
+ // Compare result of NativeConfig and AnalysisConfig
202
210
void compare (bool use_mkldnn = false ) {
203
- AnalysisConfig config;
204
- SetConfig (&config);
211
+ AnalysisConfig cfg;
212
+ SetConfig (&cfg);
213
+ if (use_mkldnn) {
214
+ cfg.EnableMKLDNN ();
215
+ }
205
216
206
217
std::vector<std::vector<PaddleTensor>> inputs;
207
218
LoadInputData (&inputs);
208
219
CompareNativeAndAnalysis (
209
- reinterpret_cast <const PaddlePredictor::Config *>(&config ), inputs);
220
+ reinterpret_cast <const PaddlePredictor::Config *>(&cfg ), inputs);
210
221
}
211
222
212
// Cross-check native vs. analysis predictions on the same inputs.
TEST(Analyzer_bert, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
// Same comparison with the MKL-DNN backend enabled on the analysis side.
TEST(Analyzer_bert, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
227

// Compare Deterministic result
// TODO(luotao): Since each unit-test on CI only has 10 minutes, cancel this to
// decrease the CI time.
// TEST(Analyzer_bert, compare_determine) {
//   AnalysisConfig cfg;
//   SetConfig(&cfg);
//
//   std::vector<std::vector<PaddleTensor>> inputs;
//   LoadInputData(&inputs);
//   CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config
//   *>(&cfg),
//                        inputs);
// }
216
241
} // namespace inference
217
242
} // namespace paddle
0 commit comments