Skip to content

Commit df31926

Browse files
committed
small thread-safety fix and doc improvements.
1 parent 59bfa49 commit df31926

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
199199
BuildStrategy::GradientScaleStrategy::kCustomized) {
200200
CreateScaleLossGradOp(&result);
201201
}
202+
// This assumes the backward generating code will ensure IsScaleLossOp
203+
// is true only for the op that scale the final scalar loss.
204+
// It also assumes backward op will always follow the forward op in
205+
// the block.
202206
is_forwarding = false;
203207
} else {
204208
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
@@ -243,6 +247,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
243247
InsertAllReduceOp(&result, g_name);
244248
}
245249
break;
250+
default:
251+
LOG(FATAL) << "Unknown reduce strategy ";
252+
break;
246253
}
247254
}
248255
} catch (boost::bad_get e) {
@@ -261,7 +268,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
261268
}
262269
/*
263270
Dependency graph has been constructed. However, there are still data
264-
harzaeds need to be handled.
271+
hazards need to be handled.
265272
*/
266273
PolishGraphToSupportDataHazards(&result);
267274

@@ -449,6 +456,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
449456
return var;
450457
}
451458

459+
// Find the first occurence of `prev_op_name` and make current `op` depend
460+
// on it.
452461
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
453462
const std::string &prev_op_name) const {
454463
for (auto &prev_op : result->ops_) {
@@ -469,6 +478,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
469478
}
470479
}
471480

481+
// Create RPC related op handles that connects its in ops and out ops.
472482
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
473483
const OpDesc &op) const {
474484
result->ops_.emplace_back(

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
9696
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
9797

9898
if (timeout) {
99+
std::lock_guard<std::mutex> l(exception_mu_);
99100
if (exception_) {
100101
auto exp = *exception_;
101102
exception_.reset();
@@ -199,6 +200,7 @@ void ThreadedSSAGraphExecutor::RunOp(
199200
ready_var_q->Extend(op->Outputs());
200201
VLOG(10) << op << " " << op->Name() << "Signal posted";
201202
} catch (platform::EnforceNotMet ex) {
203+
std::lock_guard<std::mutex> l(exception_mu_);
202204
exception_.reset(new platform::EnforceNotMet(ex));
203205
} catch (...) {
204206
LOG(FATAL) << "Unknown exception catched";

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5656
std::vector<Scope *> local_scopes_;
5757
std::vector<platform::Place> places_;
5858
platform::DeviceContextPool fetch_ctxs_;
59+
std::mutex exception_mu_;
5960
std::unique_ptr<platform::EnforceNotMet> exception_;
6061
std::atomic<int> running_ops_;
6162

0 commit comments

Comments
 (0)