@@ -58,23 +58,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
58
58
59
59
void MultiDevSSAGraphBuilder::CreateOpHandleIOs (SSAGraph *result,
60
60
const OpDesc &op,
61
- const platform::Place &p,
62
- const size_t &i) const {
61
+ size_t place_id) const {
62
+ auto p = places_[place_id];
63
63
auto *op_handle = result->ops_ .back ().get ();
64
64
op_handle->SetDeviceContext (p,
65
65
platform::DeviceContextPool::Instance ().Get (p));
66
66
67
- auto var_names = op.InputArgumentNames ();
68
-
69
- for (auto &each_var_name : var_names) {
70
- VarHandle *var = CreateOrGetLatestVarHandle (result, each_var_name, p, i);
67
+ for (auto &each_var_name : op.InputArgumentNames ()) {
68
+ VarHandle *var =
69
+ CreateOrGetLatestVarHandle (result, each_var_name, p, place_id);
71
70
op_handle->AddInput (var);
72
71
}
73
72
74
- var_names = op.OutputArgumentNames ();
75
-
76
- for (auto &each_var_name : var_names) {
77
- CreateOpOutput (result, op_handle, each_var_name, p, i);
73
+ for (auto &each_var_name : op.OutputArgumentNames ()) {
74
+ CreateOpOutput (result, op_handle, each_var_name, p, place_id);
78
75
}
79
76
}
80
77
@@ -84,17 +81,18 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
84
81
return false ;
85
82
}
86
83
87
- auto checker = [&](const std::vector<std::string> opvars,
88
- const std::vector<std::string> sendvars) -> bool {
89
- bool is_dist_train_op = false ;
84
+ /* *
85
+ * Check any of opvars contains `.block` and in sendvars
86
+ */
87
+ auto checker = [](const std::vector<std::string> &opvars,
88
+ const std::vector<std::string> &sendvars) -> bool {
90
89
for (auto &var : opvars) {
91
90
if (var.find (" .block" ) != std::string::npos &&
92
91
std::find (sendvars.begin (), sendvars.end (), var) != sendvars.end ()) {
93
- is_dist_train_op = true ;
94
- break ;
92
+ return true ;
95
93
}
96
94
}
97
- return is_dist_train_op ;
95
+ return false ;
98
96
};
99
97
100
98
if (op.Type () == " split" ) {
@@ -117,13 +115,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
117
115
places_.size ());
118
116
119
117
// Find "send" op first for split is in front of send.
120
- OpDesc *send_op = nullptr ;
121
- for (auto *op : program.Block (0 ).AllOps ()) {
122
- if (op->Type () == " send" ) {
123
- send_op = op;
124
- break ;
125
- }
126
- }
118
+ OpDesc *send_op = GetSendOpDesc (program);
127
119
128
120
bool is_forwarding = true ;
129
121
for (auto *op : program.Block (0 ).AllOps ()) {
@@ -134,6 +126,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
134
126
} else if (IsDistTrainOp (*op, send_op)) {
135
127
CreateComputationalOps (&result, *op, 1 );
136
128
} else if (IsScaleLossOp (*op)) {
129
+ // user can customize loss@grad if skip_scale_loss_
137
130
if (!skip_scale_loss_) {
138
131
CreateScaleLossGradOp (&result);
139
132
}
@@ -142,10 +135,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
142
135
CreateComputationalOps (&result, *op, places_.size ());
143
136
if (!is_forwarding) {
144
137
// Currently, we assume that once gradient is generated, it can be
145
- // broadcast, and each gradient is only broadcast once. But there are no
146
- // other cases, for example, we need to adjust the gradient according to
147
- // the input when we get the gradient, which is not considered at
148
- // present.
138
+ // broadcast, and each gradient is only broadcast once.
149
139
for (auto &og : op->OutputArgumentNames ()) {
150
140
if (IsParameterGradientOnce (og, &og_has_been_broadcast)) {
151
141
InsertNCCLAllReduceOp (&result, og);
@@ -175,6 +165,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
175
165
return std::unique_ptr<SSAGraph>(graph);
176
166
}
177
167
168
+ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc (
169
+ const ProgramDesc &program) const {
170
+ for (auto *op : program.Block (0 ).AllOps ()) {
171
+ if (op->Type () == " send" ) {
172
+ return op;
173
+ }
174
+ }
175
+ return nullptr ;
176
+ }
177
+
178
178
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp (
179
179
SSAGraph *result, const std::string &og) const {
180
180
#ifdef PADDLE_WITH_CUDA
@@ -243,7 +243,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
243
243
auto p = places_[scope_idx];
244
244
auto s = local_scopes_[scope_idx];
245
245
result->ops_ .emplace_back (new ComputationOpHandle (op, s, p));
246
- CreateOpHandleIOs (result, op, p, scope_idx);
246
+ CreateOpHandleIOs (result, op, scope_idx);
247
247
}
248
248
}
249
249
@@ -255,7 +255,7 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
255
255
result->ops_ .emplace_back (new SendOpHandle (op, s, p));
256
256
// Create inputs for output on original place and no ssa output
257
257
// is created for send op.
258
- CreateOpHandleIOs (result, op, p, 0 );
258
+ CreateOpHandleIOs (result, op, 0 );
259
259
}
260
260
261
261
bool MultiDevSSAGraphBuilder::IsScaleLossOp (const OpDesc &op) const {
0 commit comments