Skip to content

Commit 2265d09

Browse files
author
chengduo
authored
Fix threaded executor bug (#16508)
* fix threaded executor bug test=develop * change the order of class member test=develop * Fix Travis CI test=develop
1 parent f7f5044 commit 2265d09

File tree

4 files changed

+28
-23
lines changed

4 files changed

+28
-23
lines changed

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
3131
local_scopes_(local_scopes),
3232
places_(places),
3333
graph_(graph),
34+
fetch_ctxs_(places),
3435
pool_(strategy.num_threads_),
35-
prepare_pool_(1), // add one more thread for generate op_deps
36-
fetch_ctxs_(places) {
36+
// add one more thread for generate op_deps
37+
prepare_pool_(1) {
3738
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
3839
int dep = static_cast<int>(op->NotReadyInputSize());
3940
op_deps_.emplace(op, dep);

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
#pragma once
1616
#include <ThreadPool.h>
17+
#include <memory>
1718
#include <string>
19+
#include <unordered_map>
1820
#include <vector>
1921
#include "paddle/fluid/framework/blocking_queue.h"
2022
#include "paddle/fluid/framework/details/exception_holder.h"
@@ -37,6 +39,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
3739
const ir::Graph &Graph() const override;
3840

3941
private:
42+
// Note(zcd): the ThreadPool members must be declared last so that they are
43+
// destroyed first (members are destroyed in reverse declaration order).
4044
ExecutionStrategy strategy_;
4145
std::vector<Scope *> local_scopes_;
4246
std::vector<platform::Place> places_;
@@ -45,21 +49,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
4549
std::unordered_map<OpHandleBase *, int> op_deps_;
4650
std::vector<OpHandleBase *> bootstrap_ops_;
4751

48-
::ThreadPool pool_;
49-
::ThreadPool prepare_pool_;
5052
platform::DeviceContextPool fetch_ctxs_;
5153
std::atomic<int> remaining_;
5254

55+
std::future<
56+
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
57+
atomic_op_deps_;
58+
ExceptionHolder exception_;
59+
60+
::ThreadPool pool_;
61+
::ThreadPool prepare_pool_;
62+
5363
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
5464
OpHandleBase *op,
5565
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
5666

5767
void PrepareAtomicOpDeps();
58-
59-
std::future<
60-
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
61-
atomic_op_deps_;
62-
ExceptionHolder exception_;
6368
};
6469
} // namespace details
6570
} // namespace framework

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
2424
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
2525
const std::vector<platform::Place> &places, ir::Graph *graph)
2626
: graph_(graph),
27-
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
28-
: nullptr),
29-
prepare_pool_(1),
3027
local_scopes_(local_scopes),
3128
places_(places),
3229
fetch_ctxs_(places),
33-
strategy_(strategy) {
30+
strategy_(strategy),
31+
prepare_pool_(1),
32+
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
33+
: nullptr) {
3434
PrepareOpDeps();
3535
CopyOpDeps();
3636
}

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,20 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
6363
details::OpHandleBase *op);
6464

6565
private:
66+
// Note(zcd): the ThreadPool members must be declared last so that they are
67+
// destroyed first (members are destroyed in reverse declaration order).
6668
ir::Graph *graph_;
67-
std::unique_ptr<::ThreadPool> pool_;
68-
::ThreadPool prepare_pool_;
6969
std::vector<Scope *> local_scopes_;
7070
std::vector<platform::Place> places_;
7171
platform::DeviceContextPool fetch_ctxs_;
7272
ExceptionHolder exception_holder_;
73+
std::unique_ptr<OpDependentData> op_deps_;
74+
std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
75+
ExecutionStrategy strategy_;
76+
// use std::list because clear(), push_back, and for_each are O(1)
77+
std::list<std::future<void>> run_op_futures_;
78+
::ThreadPool prepare_pool_;
79+
std::unique_ptr<::ThreadPool> pool_;
7380

7481
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
7582
OpHandleBase *op_instance) const;
@@ -88,14 +95,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
8895

8996
void PrepareOpDeps();
9097
void CopyOpDeps();
91-
92-
private:
93-
std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
94-
95-
ExecutionStrategy strategy_;
96-
std::unique_ptr<OpDependentData> op_deps_;
97-
// use std::list because clear(), push_back, and for_each are O(1)
98-
std::list<std::future<void>> run_op_futures_;
9998
};
10099

101100
} // namespace details

0 commit comments

Comments
 (0)