@@ -255,8 +255,8 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
 }
 }
 // Test with a really complicated model.
-void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
-                           int num_threads = FLAGS_num_threads) {
+void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
+                           int num_threads) {
   AnalysisConfig config;
   config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
   config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
@@ -300,7 +300,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   // because AttentionLSTM's hard-coded node id will be damaged.
   for (int tid = 0; tid < num_threads; ++tid) {
     predictors.emplace_back(
-        CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(
+        CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
             config));
   }
   for (int tid = 0; tid < num_threads; ++tid) {
@@ -326,7 +326,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   }
   LOG(INFO) << "=====================================";
 
-  if (use_analysis_and_activate_ir) {
+  if (use_analysis && activate_ir) {
     AnalysisPredictor *analysis_predictor =
         dynamic_cast<AnalysisPredictor *>(predictor.get());
     auto &fuse_statis = analysis_predictor->analysis_argument()
@@ -353,15 +353,25 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   }
 }
 
-// basic unit-test of DituRNN, easy for profiling independently.
-TEST(Analyzer, DituRNN) { TestDituRNNPrediction(false, FLAGS_num_threads); }
+// Inference with analysis and IR, easy for profiling independently.
+TEST(Analyzer, DituRNN) {
+  TestDituRNNPrediction(true, true, FLAGS_num_threads);
+}
 
-// advance unit-test of DituRNN, test use_analysis_and_activate_ir and
-// multi-threads.
-TEST(Analyzer, DituRNN_multi_thread) {
-  TestDituRNNPrediction(true, 1);
-  TestDituRNNPrediction(false, 4);
-  TestDituRNNPrediction(true, 4);
+// Other unit tests of DituRNN; cover the different combinations of
+// use_analysis, activate_ir, and multi-threading.
+TEST(Analyzer, DituRNN_tests) {
+  int num_threads[2] = {1, 4};
+  for (auto i : num_threads) {
+    // Directly infer with the original model.
+    TestDituRNNPrediction(false, false, i);
+    // Inference with the original model with the analysis turned on; the
+    // analysis module will transform the program to a data flow graph.
+    TestDituRNNPrediction(true, false, i);
+    // Inference with analysis and IR. The IR module will fuse some large
+    // kernels.
+    TestDituRNNPrediction(true, true, i);
+  }
 }
 
 }  // namespace analysis
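
A minimal sketch (not part of the patch) of the corrected predictor construction that the second hunk fixes, using only the config fields set in the first hunk; MakeDituRNNPredictor is a hypothetical helper name, and the inference headers and flags are assumed to match this test file:

    std::unique_ptr<PaddlePredictor> MakeDituRNNPredictor() {
      AnalysisConfig config;
      config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
      config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
      // The engine kind and the config type must agree: kAnalysis is paired
      // with AnalysisConfig, which is the fix the hunk above applies to the
      // old NativeConfig template argument.
      return CreatePaddlePredictor<AnalysisConfig,
                                   PaddleEngineKind::kAnalysis>(config);
    }

The per-thread loop above repeats this construction once per thread rather than cloning a single predictor, per its comment about AttentionLSTM's hard-coded node ids.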