Skip to content

Commit c21597c

Browse files
authored
fix(PE): use shared_ptr<BlockingQueue> for cross thread communication (#14136)
It seems that the blocking queue might be destroyed earlier than the Run method completes. This might be because the Run method throws an unhandled exception. In any case, a resource that is accessed by multiple threads should be held by a shared_ptr, so change BlockingQueue to a shared_ptr. test=develop
1 parent 5cc99c4 commit c21597c

File tree

4 files changed

+19
-19
lines changed

4 files changed

+19
-19
lines changed

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
9292

9393
size_t num_complete = 0;
9494
remaining_ = 0;
95-
BlockingQueue<size_t> complete_q;
95+
auto complete_q = std::make_shared<BlockingQueue<size_t>>();
9696
for (auto op : bootstrap_ops_) {
97-
RunOpAsync(op_deps.get(), op, &complete_q);
97+
RunOpAsync(op_deps.get(), op, complete_q);
9898
}
9999

100100
while (num_complete != op_deps->size()) {
101-
size_t num_comp = complete_q.Pop();
101+
size_t num_comp = complete_q->Pop();
102102
if (num_comp == -1UL) {
103103
int remaining = 0;
104104
while (true) {
@@ -107,7 +107,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
107107
break;
108108
}
109109
for (int i = 0; i < remaining; ++i) {
110-
complete_q.Pop();
110+
complete_q->Pop();
111111
}
112112
}
113113
exception_.ReThrow();
@@ -120,7 +120,8 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
120120
}
121121
void FastThreadedSSAGraphExecutor::RunOpAsync(
122122
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
123-
OpHandleBase *op, BlockingQueue<size_t> *complete_q) {
123+
OpHandleBase *op,
124+
const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
124125
++remaining_;
125126
this->pool_.enqueue([=] {
126127
OpHandleBase *op_to_run = op;
@@ -144,7 +145,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
144145
if (op_to_run == nullptr) {
145146
op_to_run = pending_op;
146147
} else {
147-
this->RunOpAsync(op_deps, pending_op, complete_q);
148+
RunOpAsync(op_deps, pending_op, complete_q);
148149
}
149150
}
150151
}
@@ -156,8 +157,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
156157
}
157158
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
158159
atomic_op_deps_ = pool_.enqueue([&] {
159-
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps =
160-
new std::unordered_map<OpHandleBase *, std::atomic<int>>;
160+
auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
161161
for (auto &pair : op_deps_) {
162162
(*op_deps)[pair.first] = pair.second;
163163
}

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
5050
std::atomic<int> remaining_;
5151

5252
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
53-
OpHandleBase *op, BlockingQueue<size_t> *complete_q);
53+
OpHandleBase *op,
54+
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
5455

5556
void PrepareAtomicOpDeps();
5657

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
3939
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr));
4040
std::unordered_map<OpHandleBase *, size_t> pending_ops;
4141
std::unordered_set<VarHandleBase *> pending_vars;
42-
BlockingQueue<VarHandleBase *> ready_vars;
42+
auto ready_vars = std::make_shared<BlockingQueue<VarHandleBase *>>();
4343
std::unordered_set<OpHandleBase *> ready_ops;
4444
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
4545
// streams from multiple GPUs, it's faster to buffer them and schedule
@@ -51,12 +51,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
5151
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
5252
for (auto &name_pair : var_map) {
5353
for (auto &version_pair : name_pair.second) {
54-
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
54+
InsertPendingVar(&pending_vars, ready_vars.get(), version_pair.get());
5555
}
5656
}
5757
}
5858
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
59-
InsertPendingVar(&pending_vars, &ready_vars, var.get());
59+
InsertPendingVar(&pending_vars, ready_vars.get(), var.get());
6060
}
6161

6262
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
@@ -73,12 +73,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
7373
FeedFetchList fetch_data(fetch_tensors.size());
7474

7575
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
76-
&pending_vars, &ready_vars, &fetch_data);
76+
&pending_vars, ready_vars.get(), &fetch_data);
7777

7878
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
7979
for (auto *op : set) {
8080
running_ops_++;
81-
RunOp(&ready_vars, op);
81+
RunOp(ready_vars, op);
8282
}
8383
set.clear();
8484
};
@@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
8787
run_op_futures_.clear();
8888
exception_holder_.Clear();
8989
event.reset(nullptr);
90-
9190
// Step 3. Execution
9291
while (!pending_vars.empty()) {
9392
// 1. Run All Ready ops
@@ -103,7 +102,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
103102

104103
// 2. Find ready variable
105104
bool timeout;
106-
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
105+
auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
107106

108107
if (timeout) {
109108
if (exception_holder_.IsCaught()) {
@@ -133,7 +132,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
133132
}
134133
}
135134
PADDLE_ENFORCE(ready_ops.empty());
136-
137135
// Wait FetchOps.
138136
ClearFetchOp(graph_.get(), &fetch_ops);
139137

@@ -206,7 +204,8 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
206204
}
207205

208206
void ThreadedSSAGraphExecutor::RunOp(
209-
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
207+
const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
208+
details::OpHandleBase *op) {
210209
auto op_run = [ready_var_q, op, this] {
211210
try {
212211
if (VLOG_IS_ON(10)) {

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5151
~ThreadedSSAGraphExecutor() {}
5252

5353
private:
54-
void RunOp(BlockingQueue<VarHandleBase *> *ready_var_q,
54+
void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
5555
details::OpHandleBase *op);
5656

5757
private:

0 commit comments

Comments
 (0)