Skip to content

Commit 49080ac

Browse files
authored
Merge pull request #11621 from panyx0718/bug_fx
Merge pull request #11608 from panyx0718/doc
2 parents c660b47 + 4b446ac commit 49080ac

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
200200
BuildStrategy::GradientScaleStrategy::kCustomized) {
201201
CreateScaleLossGradOp(&result);
202202
}
203+
// This assumes the backward generating code will ensure IsScaleLossOp
204+
// is true only for the op that scales the final scalar loss.
205+
// It also assumes backward op will always follow the forward op in
206+
// the block.
203207
is_forwarding = false;
204208
} else {
205209
int op_dev_id = GetOpDeviceID(*op);
@@ -244,6 +248,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
244248
InsertAllReduceOp(&result, g_name);
245249
}
246250
break;
251+
default:
252+
LOG(FATAL) << "Unknown reduce strategy ";
253+
break;
247254
}
248255
}
249256
} catch (boost::bad_get e) {
@@ -262,7 +269,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
262269
}
263270
/*
264271
Dependency graph has been constructed. However, there are still data
265-
harzaeds need to be handled.
272+
hazards need to be handled.
266273
*/
267274
PolishGraphToSupportDataHazards(&result);
268275

@@ -447,6 +454,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
447454
return var;
448455
}
449456

457+
// Find the first occurrence of `prev_op_name` and make current `op` depend
458+
// on it.
450459
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
451460
const std::string &prev_op_name) const {
452461
for (auto &prev_op : result->ops_) {
@@ -490,6 +499,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
490499
}
491500
}
492501

502+
// Create RPC-related op handles that connect their in ops and out ops.
493503
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
494504
const OpDesc &op) const {
495505
int op_dev_id = -1;

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
9696
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
9797

9898
if (timeout) {
99+
std::lock_guard<std::mutex> l(exception_mu_);
99100
if (exception_) {
100101
auto exp = *exception_;
101102
exception_.reset();
@@ -199,6 +200,7 @@ void ThreadedSSAGraphExecutor::RunOp(
199200
ready_var_q->Extend(op->Outputs());
200201
VLOG(10) << op << " " << op->Name() << "Signal posted";
201202
} catch (platform::EnforceNotMet ex) {
203+
std::lock_guard<std::mutex> l(exception_mu_);
202204
exception_.reset(new platform::EnforceNotMet(ex));
203205
} catch (...) {
204206
LOG(FATAL) << "Unknown exception catched";

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5656
std::vector<Scope *> local_scopes_;
5757
std::vector<platform::Place> places_;
5858
platform::DeviceContextPool fetch_ctxs_;
59+
std::mutex exception_mu_;
5960
std::unique_ptr<platform::EnforceNotMet> exception_;
6061
std::atomic<int> running_ops_;
6162

0 commit comments

Comments
 (0)