
Commit f615ba2

update the multi-thread unit-tests
1 parent 35cff5e commit f615ba2

1 file changed: 23 additions, 12 deletions

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 23 additions & 12 deletions
@@ -255,8 +255,8 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
   }
 }
 // Test with a really complicate model.
-void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
-                           int num_threads = FLAGS_num_threads) {
+void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
+                           int num_threads) {
   AnalysisConfig config;
   config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
   config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
@@ -300,7 +300,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   // because AttentionLSTM's hard code nodeid will be damanged.
   for (int tid = 0; tid < num_threads; ++tid) {
     predictors.emplace_back(
-        CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(
+        CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
             config));
   }
   for (int tid = 0; tid < num_threads; ++tid) {
@@ -326,7 +326,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   }
   LOG(INFO) << "=====================================";

-  if (use_analysis_and_activate_ir) {
+  if (use_analysis && activate_ir) {
     AnalysisPredictor *analysis_predictor =
         dynamic_cast<AnalysisPredictor *>(predictor.get());
     auto &fuse_statis = analysis_predictor->analysis_argument()
@@ -353,15 +353,26 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   }
 }

-// basic unit-test of DituRNN, easy for profiling independently.
-TEST(Analyzer, DituRNN) { TestDituRNNPrediction(false, FLAGS_num_threads); }
+// Inference with analysis and IR, easy for profiling independently.
+TEST(Analyzer, DituRNN) {
+  TestDituRNNPrediction(true, true, FLAGS_num_threads);
+}

-// advance unit-test of DituRNN, test use_analysis_and_activate_ir and
-// multi-threads.
-TEST(Analyzer, DituRNN_multi_thread) {
-  TestDituRNNPrediction(true, 1);
-  TestDituRNNPrediction(false, 4);
-  TestDituRNNPrediction(true, 4);
+// Other unit-tests of DituRNN, test different options of use_analysis,
+// activate_ir and multi-threads.
+TEST(Analyzer, DituRNN_tests) {
+  int num_threads[2] = {1, 4};
+  for (auto i : num_threads) {
+    // Directly infer with the original model.
+    TestDituRNNPrediction(false, false, i);
+    // Inference with the original model with the analysis turned on, the
+    // analysis
+    // module will transform the program to a data flow graph.
+    TestDituRNNPrediction(true, false, i);
+    // Inference with analysis and IR. The IR module will fuse some large
+    // kernels.
+    TestDituRNNPrediction(true, true, i);
+  }
 }

 } // namespace analysis
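
For readers picking up the new three-argument TestDituRNNPrediction(use_analysis, activate_ir, num_threads), the sketch below illustrates the per-thread predictor pattern the multi-thread test exercises: every thread gets its own predictor created from the same AnalysisConfig, so nothing is shared between threads at run time. This is not part of the commit; only CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>, PaddlePredictor::Run, and PaddleTensor come from the diff and the Paddle inference API of that time, while RunMultiThreaded and PrepareInputs are hypothetical names.

// Sketch only, assuming the Paddle inference headers used by analyzer_tester.cc
// are included; RunMultiThreaded and PrepareInputs are hypothetical.
#include <memory>
#include <thread>
#include <vector>

void RunMultiThreaded(const AnalysisConfig &config, int num_threads) {
  // One predictor per thread, all built from the same config, mirroring the
  // first loop in the test.
  std::vector<std::unique_ptr<PaddlePredictor>> predictors;
  for (int tid = 0; tid < num_threads; ++tid) {
    predictors.emplace_back(
        CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
            config));
  }
  // Each thread runs only its own predictor, so no locking is required.
  std::vector<std::thread> threads;
  for (int tid = 0; tid < num_threads; ++tid) {
    threads.emplace_back([&predictors, tid] {
      std::vector<PaddleTensor> inputs = PrepareInputs();  // hypothetical helper
      std::vector<PaddleTensor> outputs;
      predictors[tid]->Run(inputs, &outputs);
    });
  }
  for (auto &t : threads) t.join();
}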
