Commit 39ed148

fix multi-thread hang temporary
1 parent 459d4cc commit 39ed148

1 file changed: 12 additions, 16 deletions

paddle/fluid/inference/analysis/analyzer_tester.cc

@@ -260,11 +260,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
 
   LOG(INFO) << "===========profile result===========";
   if (num_threads == 1) {
-    std::vector<PaddleTensor> input_slots;
     // Prepare inputs.
-    DataRecord data(FLAGS_infer_ditu_rnn_data, batch_size);
-    PrepareInputs(&input_slots, &data, batch_size);
-
     Timer timer;
     timer.tic();
     for (int i = 0; i < num_times; i++) {
@@ -273,21 +269,20 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
     print_time(batch_size, num_times, 1, 0, timer.toc() / num_times);
   } else {
     std::vector<std::thread> threads;
-    std::vector<PaddleTensor> input_slots;
-    // Prepare inputs.
-    PrepareInputs(&input_slots, &data, batch_size);
-    std::vector<PaddleTensor> outputs;
+    std::vector<std::unique_ptr<PaddlePredictor>> predictors;
+    // TODO(yanchunwei): Bug here, the analyzer phase can't be parallelled
+    // because AttentionLSTM's hard code nodeid will be damanged.
+    for (int tid = 0; tid < num_threads; ++tid) {
+      predictors.emplace_back(
+          CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(
+              config));
+    }
     for (int tid = 0; tid < num_threads; ++tid) {
       threads.emplace_back([&, tid]() {
-        auto predictor_tid =
-            CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(
-                config);
-        DataRecord data(FLAGS_infer_ditu_rnn_data, batch_size);
-
         Timer timer;
         timer.tic();
         for (int i = 0; i < num_times; i++) {
-          predictor_tid->Run(input_slots, &outputs);
+          predictors[tid]->Run(input_slots, &outputs);
         }
         print_time(batch_size, num_times, num_threads, tid,
                    timer.toc() / num_times);
@@ -348,8 +343,9 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
 }
 
 TEST(Analyzer, DituRNN) {
-  TestDituRNNPrediction(false, 1);
-  TestDituRNNPrediction(true, 1);
+  // default FLAGS_num_threads = 1
+  TestDituRNNPrediction(false, FLAGS_num_threads);
+  TestDituRNNPrediction(true, FLAGS_num_threads);
 }
 
 TEST(Analyzer, DituRNN_multi_thread) {
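
For readers skimming the diff, the pattern the commit switches to is: build every predictor serially on the main thread (because, per the TODO added above, the analysis phase is not safe to run in parallel), then let each worker thread call Run only on its own pre-built predictor. Below is a minimal, self-contained C++ sketch of that same structure. It is not part of the commit; FakePredictor is a hypothetical stand-in for the real PaddlePredictor created via CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(config), and the thread/loop counts are arbitrary.

// Sketch only: FakePredictor stands in for the real PaddlePredictor, whose
// construction (the analysis phase) must not run concurrently.
#include <cstdio>
#include <memory>
#include <thread>
#include <vector>

struct FakePredictor {
  explicit FakePredictor(int id) : id_(id) {}  // pretend this runs analysis
  void Run() const { std::printf("predictor %d ran one batch\n", id_); }
  int id_;
};

int main() {
  const int num_threads = 4;
  const int num_times = 3;

  // 1) Build all predictors serially on the main thread, mirroring the diff:
  //    the analysis phase cannot be parallelized safely.
  std::vector<std::unique_ptr<FakePredictor>> predictors;
  for (int tid = 0; tid < num_threads; ++tid) {
    predictors.emplace_back(std::make_unique<FakePredictor>(tid));
  }

  // 2) Each worker thread only drives its own pre-built predictor.
  std::vector<std::thread> threads;
  for (int tid = 0; tid < num_threads; ++tid) {
    threads.emplace_back([&, tid]() {
      for (int i = 0; i < num_times; ++i) {
        predictors[tid]->Run();
      }
    });
  }
  for (auto &t : threads) t.join();
  return 0;
}

The design choice the sketch mirrors is that only the inference call runs concurrently; everything that touches the not-yet-thread-safe analysis pass stays on a single thread, which is the temporary workaround the commit message refers to.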
