Skip to content

Commit e9abc66

Browse files
committed
fix pe
1 parent 952fa04 commit e9abc66

File tree

5 files changed

+82
-34
lines changed

5 files changed

+82
-34
lines changed

paddle/fluid/framework/details/computation_op_handle.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ void ComputationOpHandle::RunImpl() {
2929
WaitInputVarGenerated(place_);
3030

3131
this->RunAndRecordEvent([this] {
32+
VLOG(3) << "begin run op type is " << op_->Type();
3233
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
34+
VLOG(3) << "end run op type is " << op_->Type();
3335
});
3436
}
3537

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
15-
#include <fstream>
1615
#include <utility>
1716
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
1817
#include "paddle/fluid/framework/details/computation_op_handle.h"
@@ -79,31 +78,63 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
7978
CreateOpOutput(result, op_handle, each_var_name, p, place_id);
8079
}
8180
}
82-
bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
83-
OpDesc *send_op) const {
84-
if (send_op == nullptr) {
81+
82+
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
83+
const ProgramDesc &program) const {
84+
std::vector<std::string> send_vars;
85+
for (auto *op : program.Block(0).AllOps()) {
86+
if (op->Type() == "send_vars" || op->Type() == "send") {
87+
auto op_vars = op->InputArgumentNames();
88+
send_vars.reserve(send_vars.size() +
89+
std::distance(op_vars.begin(), op_vars.end()));
90+
send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
91+
}
92+
}
93+
return send_vars;
94+
}
95+
96+
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
97+
const ProgramDesc &program) const {
98+
std::vector<std::string> recv_vars;
99+
for (auto *op : program.Block(0).AllOps()) {
100+
if (op->Type() == "recv" || op->Type() == "send") {
101+
auto op_vars = op->OutputArgumentNames();
102+
recv_vars.reserve(recv_vars.size() +
103+
std::distance(op_vars.begin(), op_vars.end()));
104+
recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
105+
}
106+
}
107+
return recv_vars;
108+
}
109+
110+
bool MultiDevSSAGraphBuilder::IsDistTrainOp(
111+
const OpDesc &op, const std::vector<std::string> &send_vars,
112+
const std::vector<std::string> &recv_vars) const {
113+
if (send_vars.size() == 0 || recv_vars.size() == 0) {
85114
return false;
86115
}
87116

88117
/**
89118
* Check whether any of opvars contains `.block` and is in sendvars
90119
*/
91120
auto checker = [](const std::vector<std::string> &opvars,
92-
const std::vector<std::string> &sendvars) -> bool {
121+
const std::vector<std::string> &rpc_vars) -> bool {
93122
for (auto &var : opvars) {
94123
if (var.find(".block") != std::string::npos &&
95-
std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
124+
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
96125
return true;
97126
}
98127
}
99128
return false;
100129
};
101130

102-
if (op.Type() == "split" || op.Type() == "split_byref") {
103-
return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
131+
if (op.Type() == "split" || op.Type() == "split_byref" ||
132+
op.Type() == "split_selected_rows") {
133+
return checker(op.OutputArgumentNames(), send_vars);
104134
} else if (op.Type() == "concat") {
105-
return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
135+
return checker(op.InputArgumentNames(), recv_vars);
106136
}
137+
107138
return false;
108139
}
109140

@@ -132,8 +163,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
132163
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
133164
places_.size());
134165

135-
// Find "send" op first for split is in front of send.
136-
OpDesc *send_op = GetSendOpDesc(program);
166+
// find send/recv vars so that we can place the distributed training
167+
// related op in the place 0
168+
auto send_vars = FindDistTrainSendVars(program);
169+
auto recv_vars = FindDistTrainRecvVars(program);
137170

138171
size_t cur_device_id = 0;
139172
std::vector<std::unordered_set<std::string>> var_name_on_devices;
@@ -147,8 +180,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
147180
// append rpc op if program is distributed trainer main program.
148181
// always use the first device
149182
CreateRPCOp(&result, *op);
150-
} else if (IsDistTrainOp(*op, send_op)) {
151-
CreateComputationalOps(&result, *op, 1);
183+
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
184+
// CreateComputationalOps(&result, *op, 1);
185+
CreateComputationalOp(&result, *op, 0);
152186
} else if (IsScaleLossOp(*op)) {
153187
// user can customize loss@grad if not use_default_grad_scale_
154188
if (strategy_.gradient_scale_ !=
@@ -213,9 +247,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
213247
AddOutputToLeafOps(&result);
214248

215249
if (VLOG_IS_ON(10)) {
216-
std::string filename = "/tmp/graph";
217-
std::ofstream fout(filename);
218-
PrintGraphviz(*graph, fout);
250+
std::ostringstream sout;
251+
PrintGraphviz(*graph, sout);
252+
VLOG(10) << sout.str();
219253
}
220254

221255
return std::unique_ptr<SSAGraph>(graph);
@@ -274,6 +308,7 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
274308
}
275309
return nullptr;
276310
}
311+
277312
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
278313
SSAGraph *result, const std::string &og) const {
279314
#ifdef PADDLE_WITH_CUDA
@@ -396,14 +431,14 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
396431
return var;
397432
}
398433

399-
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result,
400-
std::string op_name) const {
434+
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
435+
const std::string &prev_op_name) const {
401436
for (auto &prev_op : result->ops_) {
402-
if (prev_op->Name() == op_name) {
437+
if (prev_op->Name() == prev_op_name) {
403438
auto *dep_var = new DummyVarHandle();
404439
prev_op->AddOutput(dep_var);
405440
result->dep_vars_.emplace(dep_var);
406-
result->ops_.back().get()->AddInput(dep_var);
441+
op->AddInput(dep_var);
407442
}
408443
}
409444
}
@@ -412,14 +447,14 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
412447
const OpDesc &op) const {
413448
auto &p = places_[0];
414449
auto *s = local_scopes_[0];
415-
VLOG(3) << "create rpc op: " << op.Type();
416450
result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
451+
417452
if (op.Type() == "send_barrier") {
418-
ConnectOp(result, "send_vars");
453+
ConnectOp(result, result->ops_.back().get(), "send_vars");
419454
} else if (op.Type() == "recv") {
420-
ConnectOp(result, "send_barrier");
455+
ConnectOp(result, result->ops_.back().get(), "send_barrier");
421456
} else if (op.Type() == "fetch_barrier") {
422-
ConnectOp(result, "recv");
457+
ConnectOp(result, result->ops_.back().get(), "recv");
423458
} else if (op.Type() == "send" || op.Type() == "send_vars") {
424459
// do nothing
425460
} else {
@@ -429,7 +464,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
429464
}
430465

431466
// FIXME(wuyi): send op always copy from GPU 0
432-
// result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
433467
// Create inputs for output on original place and no ssa output
434468
// is created for send op.
435469
CreateOpHandleIOs(result, op, 0);

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,25 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
6464

6565
bool IsScaleLossOp(const OpDesc &op) const;
6666

67-
void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
6867
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
6968

7069
/**
7170
* Is this operator as the end-point operator before/after send operator.
7271
*/
73-
bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
72+
bool IsDistTrainOp(const OpDesc &op,
73+
const std::vector<std::string> &send_vars,
74+
const std::vector<std::string> &recv_vars) const;
75+
76+
std::vector<std::string> FindDistTrainSendVars(
77+
const ProgramDesc &program) const;
78+
79+
std::vector<std::string> FindDistTrainRecvVars(
80+
const ProgramDesc &program) const;
7481

7582
bool IsRPCOp(const OpDesc &op) const;
7683

77-
void ConnectOp(SSAGraph *result, std::string op_name) const;
84+
void ConnectOp(SSAGraph *result, OpHandleBase *op,
85+
const std::string &prev_op_name) const;
7886

7987
void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
8088
size_t num_places) const;

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,17 +245,11 @@ bool RPCClient::Proceed() {
245245
}
246246
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
247247
const std::string& key) {
248-
VLOG(3) << "this addr: " << this;
249248
std::unique_lock<std::mutex> lock(mutex_);
250249
auto it = channels_.find(key);
251250
if (it != channels_.end()) {
252-
VLOG(3) << "find ep: " << ep;
253251
return it->second;
254252
}
255-
VLOG(3) << "can not find ep: " << ep;
256-
for (auto it = channels_.begin(); it != channels_.end(); ++it) {
257-
VLOG(3) << "ep: " << it->first;
258-
}
259253

260254
grpc::ChannelArguments args;
261255
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,16 @@ def transpile(self,
373373
for i, ep in enumerate(eplist):
374374
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
375375
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
376+
# step4: Concat the parameter splits together after recv.
377+
for varname, splited_var in param_var_mapping.iteritems():
378+
if len(splited_var) <= 1:
379+
continue
380+
orig_param = program.global_block().vars[varname]
381+
program.global_block().append_op(
382+
type="concat",
383+
inputs={"X": splited_var},
384+
outputs={"Out": [orig_param]},
385+
attrs={"axis": 0})
376386

377387
# TODO(Yancey1989): check dist lookup table
378388
if self.has_distributed_lookup_table:

0 commit comments

Comments
 (0)