@@ -260,11 +260,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
260
260
261
261
LOG (INFO) << " ===========profile result===========" ;
262
262
if (num_threads == 1 ) {
263
- std::vector<PaddleTensor> input_slots;
264
263
// Prepare inputs.
265
- DataRecord data (FLAGS_infer_ditu_rnn_data, batch_size);
266
- PrepareInputs (&input_slots, &data, batch_size);
267
-
268
264
Timer timer;
269
265
timer.tic ();
270
266
for (int i = 0 ; i < num_times; i++) {
@@ -273,21 +269,20 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
273
269
print_time (batch_size, num_times, 1 , 0 , timer.toc () / num_times);
274
270
} else {
275
271
std::vector<std::thread> threads;
276
- std::vector<PaddleTensor> input_slots;
277
- // Prepare inputs.
278
- PrepareInputs (&input_slots, &data, batch_size);
279
- std::vector<PaddleTensor> outputs;
272
+ std::vector<std::unique_ptr<PaddlePredictor>> predictors;
273
+ // TODO(yanchunwei): Bug here, the analyzer phase can't be parallelled
274
+ // because AttentionLSTM's hard code nodeid will be damanged.
275
+ for (int tid = 0 ; tid < num_threads; ++tid) {
276
+ predictors.emplace_back (
277
+ CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis >(
278
+ config));
279
+ }
280
280
for (int tid = 0 ; tid < num_threads; ++tid) {
281
281
threads.emplace_back ([&, tid]() {
282
- auto predictor_tid =
283
- CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis >(
284
- config);
285
- DataRecord data (FLAGS_infer_ditu_rnn_data, batch_size);
286
-
287
282
Timer timer;
288
283
timer.tic ();
289
284
for (int i = 0 ; i < num_times; i++) {
290
- predictor_tid ->Run (input_slots, &outputs);
285
+ predictors[tid] ->Run (input_slots, &outputs);
291
286
}
292
287
print_time (batch_size, num_times, num_threads, tid,
293
288
timer.toc () / num_times);
@@ -348,8 +343,9 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
348
343
}
349
344
350
345
TEST (Analyzer, DituRNN) {
351
- TestDituRNNPrediction (false , 1 );
352
- TestDituRNNPrediction (true , 1 );
346
+ // default FLAGS_num_threads = 1
347
+ TestDituRNNPrediction (false , FLAGS_num_threads);
348
+ TestDituRNNPrediction (true , FLAGS_num_threads);
353
349
}
354
350
355
351
TEST (Analyzer, DituRNN_multi_thread) {
0 commit comments