@@ -39,7 +39,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
39
39
new platform::RecordEvent (" ThreadedSSAGraphExecutorPrepare" , nullptr ));
40
40
std::unordered_map<OpHandleBase *, size_t > pending_ops;
41
41
std::unordered_set<VarHandleBase *> pending_vars;
42
- BlockingQueue<VarHandleBase *> ready_vars ;
42
+ auto ready_vars = std::make_shared< BlockingQueue<VarHandleBase *>>() ;
43
43
std::unordered_set<OpHandleBase *> ready_ops;
44
44
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
45
45
// streams from multiple GPUs, it's faster to buffer them and schedule
@@ -51,12 +51,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
51
51
for (auto &var_map : graph_->Get <details::GraphVars>(details::kGraphVars )) {
52
52
for (auto &name_pair : var_map) {
53
53
for (auto &version_pair : name_pair.second ) {
54
- InsertPendingVar (&pending_vars, & ready_vars, version_pair.get ());
54
+ InsertPendingVar (&pending_vars, ready_vars. get () , version_pair.get ());
55
55
}
56
56
}
57
57
}
58
58
for (auto &var : graph_->Get <details::GraphDepVars>(details::kGraphDepVars )) {
59
- InsertPendingVar (&pending_vars, & ready_vars, var.get ());
59
+ InsertPendingVar (&pending_vars, ready_vars. get () , var.get ());
60
60
}
61
61
62
62
for (auto &op : graph_->Get <details::GraphOps>(details::kGraphOps )) {
@@ -73,12 +73,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
73
73
FeedFetchList fetch_data (fetch_tensors.size ());
74
74
75
75
InsertFetchOps (fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
76
- &pending_vars, & ready_vars, &fetch_data);
76
+ &pending_vars, ready_vars. get () , &fetch_data);
77
77
78
78
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
79
79
for (auto *op : set) {
80
80
running_ops_++;
81
- RunOp (& ready_vars, op);
81
+ RunOp (ready_vars, op);
82
82
}
83
83
set.clear ();
84
84
};
@@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
87
87
run_op_futures_.clear ();
88
88
exception_holder_.Clear ();
89
89
event.reset (nullptr );
90
-
91
90
// Step 3. Execution
92
91
while (!pending_vars.empty ()) {
93
92
// 1. Run All Ready ops
@@ -103,7 +102,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
103
102
104
103
// 2. Find ready variable
105
104
bool timeout;
106
- auto cur_ready_vars = ready_vars. PopAll (1 , &timeout);
105
+ auto cur_ready_vars = ready_vars-> PopAll (1 , &timeout);
107
106
108
107
if (timeout) {
109
108
if (exception_holder_.IsCaught ()) {
@@ -133,7 +132,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
133
132
}
134
133
}
135
134
PADDLE_ENFORCE (ready_ops.empty ());
136
-
137
135
// Wait FetchOps.
138
136
ClearFetchOp (graph_.get (), &fetch_ops);
139
137
@@ -206,7 +204,8 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
206
204
}
207
205
208
206
void ThreadedSSAGraphExecutor::RunOp (
209
- BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
207
+ const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
208
+ details::OpHandleBase *op) {
210
209
auto op_run = [ready_var_q, op, this ] {
211
210
try {
212
211
if (VLOG_IS_ON (10 )) {
0 commit comments