
Commit 9612c7e

Add comments and polish code
1 parent 76c4ae8 commit 9612c7e

4 files changed: +51 -30 lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 29 additions & 29 deletions
@@ -58,23 +58,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
 
 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
                                                 const OpDesc &op,
-                                                const platform::Place &p,
-                                                const size_t &i) const {
+                                                size_t place_id) const {
+  auto p = places_[place_id];
   auto *op_handle = result->ops_.back().get();
   op_handle->SetDeviceContext(p,
                               platform::DeviceContextPool::Instance().Get(p));
 
-  auto var_names = op.InputArgumentNames();
-
-  for (auto &each_var_name : var_names) {
-    VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
+  for (auto &each_var_name : op.InputArgumentNames()) {
+    VarHandle *var =
+        CreateOrGetLatestVarHandle(result, each_var_name, p, place_id);
     op_handle->AddInput(var);
   }
 
-  var_names = op.OutputArgumentNames();
-
-  for (auto &each_var_name : var_names) {
-    CreateOpOutput(result, op_handle, each_var_name, p, i);
+  for (auto &each_var_name : op.OutputArgumentNames()) {
+    CreateOpOutput(result, op_handle, each_var_name, p, place_id);
   }
 }
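Note: the new signature drops the redundant (place, index) pair, since the place is always places_[place_id]. A minimal standalone sketch of that idea, using a hypothetical Builder with string-valued places instead of the real Paddle types:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for MultiDevSSAGraphBuilder: the builder already owns
// the list of places, so callers only pass an index into it.
struct Builder {
  std::vector<std::string> places_;

  void CreateOpHandleIOs(std::size_t place_id) const {
    const std::string &p = places_[place_id];  // place derived from the id
    std::cout << "wiring op handle I/O on " << p << "\n";
  }
};

int main() {
  Builder b{{"place 0", "place 1"}};
  b.CreateOpHandleIOs(1);  // no need to pass the place itself any more
  return 0;
}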

@@ -84,17 +81,18 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
     return false;
   }
 
-  auto checker = [&](const std::vector<std::string> opvars,
-                     const std::vector<std::string> sendvars) -> bool {
-    bool is_dist_train_op = false;
+  /**
+   * Check whether any of opvars contains `.block` and appears in sendvars.
+   */
+  auto checker = [](const std::vector<std::string> &opvars,
+                    const std::vector<std::string> &sendvars) -> bool {
     for (auto &var : opvars) {
       if (var.find(".block") != std::string::npos &&
           std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
-        is_dist_train_op = true;
-        break;
+        return true;
       }
     }
-    return is_dist_train_op;
+    return false;
   };
 
   if (op.Type() == "split") {
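Note: the checker returns true when any op variable both contains `.block` and appears among the send op's variables. A self-contained sketch of the same predicate, exercised with hypothetical variable names (the actual names produced by split/send may differ):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Same predicate as the lambda above, written as a free function.
bool AnyBlockVarIsSent(const std::vector<std::string> &opvars,
                       const std::vector<std::string> &sendvars) {
  for (const auto &var : opvars) {
    if (var.find(".block") != std::string::npos &&
        std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
      return true;
    }
  }
  return false;
}

int main() {
  std::vector<std::string> opvars = {"w.block0", "w.block1"};  // hypothetical
  std::vector<std::string> sendvars = {"w.block0"};            // hypothetical
  std::cout << std::boolalpha << AnyBlockVarIsSent(opvars, sendvars) << "\n";
  return 0;  // prints "true"
}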
@@ -117,13 +115,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       places_.size());
 
   // Find "send" op first for split is in front of send.
-  OpDesc *send_op = nullptr;
-  for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "send") {
-      send_op = op;
-      break;
-    }
-  }
+  OpDesc *send_op = GetSendOpDesc(program);
 
   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
@@ -134,6 +126,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     } else if (IsDistTrainOp(*op, send_op)) {
      CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
+      // The user can customize loss@grad if skip_scale_loss_ is set.
       if (!skip_scale_loss_) {
         CreateScaleLossGradOp(&result);
       }
@@ -142,10 +135,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       CreateComputationalOps(&result, *op, places_.size());
       if (!is_forwarding) {
         // Currently, we assume that once gradient is generated, it can be
-        // broadcast, and each gradient is only broadcast once. But there are no
-        // other cases, for example, we need to adjust the gradient according to
-        // the input when we get the gradient, which is not considered at
-        // present.
+        // broadcast, and each gradient is only broadcast once.
         for (auto &og : op->OutputArgumentNames()) {
           if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
             InsertNCCLAllReduceOp(&result, og);
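Note: the loop relies on IsParameterGradientOnce so that each gradient is all-reduced only once. A simplified sketch of that bookkeeping (assumed behavior; the real helper presumably also checks that the name is a parameter gradient):

#include <cassert>
#include <string>
#include <unordered_set>

// Returns true only the first time a gradient name is seen; the name is then
// recorded so later occurrences are skipped.
bool BroadcastOnce(const std::string &og,
                   std::unordered_set<std::string> *og_has_been_broadcast) {
  if (og_has_been_broadcast->count(og) != 0) {
    return false;  // this gradient was already all-reduced
  }
  og_has_been_broadcast->insert(og);
  return true;
}

int main() {
  std::unordered_set<std::string> seen;
  assert(BroadcastOnce("w@GRAD", &seen));   // first occurrence: broadcast
  assert(!BroadcastOnce("w@GRAD", &seen));  // second occurrence: skip
  return 0;
}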
@@ -175,6 +165,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   return std::unique_ptr<SSAGraph>(graph);
 }
 
+OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
+    const ProgramDesc &program) const {
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "send") {
+      return op;
+    }
+  }
+  return nullptr;
+}
+
 void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     SSAGraph *result, const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
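Note: GetSendOpDesc is a plain linear scan that returns the first "send" op, or nullptr when the program has none. The same lookup could be written with std::find_if; a small sketch using a hypothetical FakeOp stand-in for OpDesc:

#include <algorithm>
#include <string>
#include <vector>

struct FakeOp {  // hypothetical stand-in for OpDesc
  std::string type;
};

// First op whose type is "send", or nullptr if there is none.
FakeOp *FindSendOp(std::vector<FakeOp> &ops) {
  auto it = std::find_if(ops.begin(), ops.end(),
                         [](const FakeOp &op) { return op.type == "send"; });
  return it == ops.end() ? nullptr : &*it;
}

int main() {
  std::vector<FakeOp> ops = {{"split"}, {"send"}, {"sum"}};
  return FindSendOp(ops) != nullptr ? 0 : 1;
}

Either form works; the explicit loop in the commit just avoids the extra lambda.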
@@ -243,7 +243,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
     result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
-    CreateOpHandleIOs(result, op, p, scope_idx);
+    CreateOpHandleIOs(result, op, scope_idx);
   }
 }

@@ -255,7 +255,7 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
   result->ops_.emplace_back(new SendOpHandle(op, s, p));
   // Create inputs for output on original place and no ssa output
   // is created for send op.
-  CreateOpHandleIOs(result, op, p, 0);
+  CreateOpHandleIOs(result, op, 0);
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 10 additions & 1 deletion
@@ -48,7 +48,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
  private:
   void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
-                         const platform::Place &p, const size_t &i) const;
+                         size_t place_id) const;
 
  private:
   std::string loss_var_name_;
@@ -65,6 +65,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
   void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
 
+  /**
+   * Is this operator the end-point operator before/after the send operator?
+   */
   bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
 
   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
@@ -77,6 +80,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
       std::unordered_set<std::string> *og_has_been_broadcast) const;
 
   void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
+
+  /**
+   * Get the send op in the global block of the program.
+   * Returns nullptr if not found.
+   */
+  OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
 };
 } // namespace details
 } // namespace framework

paddle/fluid/framework/details/ssa_graph.h

Lines changed: 10 additions & 0 deletions
@@ -25,12 +25,22 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+// An SSA graph used by the parallel executor.
 struct SSAGraph {
+  // All variables on each device.
+  // The outer vector is indexed by device. Each element of this vector is a
+  // map from variable name to variables. Variables that share the same name
+  // have different versions; the offset in the
+  // `std::vector<std::unique_ptr<VarHandle>>` is the version of the variable.
   std::vector<
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
       vars_;
+
   // aux variables to represent dependency. Useful to resolve data hazard.
   std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
+
+  // All operators. NOTE that even though we use a vector here, the operators
+  // are unordered.
   std::vector<std::unique_ptr<OpHandleBase>> ops_;
 };
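Note: the new comment on vars_ describes a three-level layout: device index -> variable name -> SSA version. A minimal sketch with simplified types (not the actual Paddle headers), where writing the same hypothetical name "fc_0.w_0" twice on one device produces versions 0 and 1:

#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-in for details::VarHandle.
struct VarHandle {
  std::string name;
  std::size_t version;
};

int main() {
  // device index -> (variable name -> all versions of that variable)
  std::vector<
      std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
      vars(2);  // e.g. two devices

  // The offset in the inner vector is the version, as the comment describes.
  vars[0]["fc_0.w_0"].emplace_back(new VarHandle{"fc_0.w_0", 0});
  vars[0]["fc_0.w_0"].emplace_back(new VarHandle{"fc_0.w_0", 1});

  const VarHandle &latest = *vars[0]["fc_0.w_0"].back();
  return latest.version == 1 ? 0 : 1;  // the latest version is the last entry
}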

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@ class SSAGraphBuilder {
                                                const platform::Place &place,
                                                size_t place_offset);
 
+  // Add an output variable (each_var_name, place, place_offset) to op_handle,
+  // which belongs to the graph.
   static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
                              const std::string &each_var_name,
                              const platform::Place &place, size_t place_offset);
