Skip to content

Commit 6debbcd

Browse files
committed
connect fetch barrier and concat op
1 parent 147d54b commit 6debbcd

File tree

5 files changed

+50
-18
lines changed

5 files changed

+50
-18
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
15+
#include <fstream>
1516
#include <utility>
1617
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
1718
#include "paddle/fluid/framework/details/computation_op_handle.h"
@@ -181,8 +182,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
181182
// always use the first device
182183
CreateRPCOp(&result, *op);
183184
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
184-
// CreateComputationalOps(&result, *op, 1);
185-
CreateComputationalOp(&result, *op, 0);
185+
CreateDistTrainOp(&result, *op);
186186
} else if (IsScaleLossOp(*op)) {
187187
// user can customize loss@grad if not use_default_grad_scale_
188188
if (strategy_.gradient_scale_ !=
@@ -247,9 +247,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
247247
AddOutputToLeafOps(&result);
248248

249249
if (VLOG_IS_ON(10)) {
250-
std::ostringstream sout;
251-
PrintGraphviz(*graph, sout);
252-
VLOG(10) << sout.str();
250+
std::ofstream fout("/tmp/graph.dot");
251+
PrintGraphviz(*graph, fout);
253252
}
254253

255254
return std::unique_ptr<SSAGraph>(graph);
@@ -443,6 +442,14 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
443442
}
444443
}
445444

445+
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
446+
const OpDesc &op) const {
447+
CreateComputationalOp(result, op, 0);
448+
if (op.Type() == "concat") {
449+
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
450+
}
451+
}
452+
446453
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
447454
const OpDesc &op) const {
448455
auto &p = places_[0];

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
6565
bool IsScaleLossOp(const OpDesc &op) const;
6666

6767
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
68+
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
6869

6970
/**
7071
* Is this operator as the end-point operator before/after send operator.

paddle/fluid/operators/recv_op.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class RecvOp : public framework::OperatorBase {
3838
auto outs = Outputs("Out");
3939
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
4040
auto client_var_name = Output("RPCClient");
41+
int sync_recv = Attr<int>("sync_recv");
4142

4243
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
4344
auto& ctx = *pool.Get(place);
@@ -54,7 +55,9 @@ class RecvOp : public framework::OperatorBase {
5455
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
5556
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
5657
}
57-
PADDLE_ENFORCE(rpc_client->Wait());
58+
if (sync_recv) {
59+
PADDLE_ENFORCE(rpc_client->Wait());
60+
}
5861
}
5962
};
6063

@@ -75,6 +78,10 @@ This operator can get variables from server side.
7578
"Server endpoints in the order of input "
7679
"variables for mapping")
7780
.SetDefault({});
81+
AddAttr<int>("sync_recv",
82+
"(int, default 0)"
83+
"sync recv or async recv.")
84+
.SetDefault(0);
7885
}
7986
};
8087

paddle/fluid/operators/send_vars_op.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,7 @@ class SendVarsOp : public framework::OperatorBase {
5050
"Can not find variable '%s' in the scope.",
5151
client_var_name);
5252
auto* client_var = scope.FindVar(client_var_name);
53-
VLOG(3) << "client var addr: " << client_var;
5453
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
55-
VLOG(3) << "rpc_client addr: " << rpc_client;
5654

5755
for (size_t i = 0; i < ins.size(); i++) {
5856
if (NeedSend(scope, ins[i])) {

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -357,23 +357,42 @@ def transpile(self,
357357
ps_dispatcher.reset()
358358
eplist = ps_dispatcher.dispatch(recv_vars)
359359

360-
program.global_block().append_op(
361-
type="recv",
362-
inputs={},
363-
outputs={"Out": recv_vars,
364-
"RPCClient": rpc_client_var},
365-
attrs={"epmap": eplist})
360+
#program.global_block().append_op(
361+
# type="recv",
362+
# inputs={},
363+
# outputs={"Out": recv_vars,
364+
# "RPCClient": rpc_client_var},
365+
# attrs={"epmap": eplist})
366+
367+
#program.global_block().append_op(
368+
# type="fetch_barrier",
369+
# inputs={},
370+
# outputs={"RPCClient": rpc_client_var},
371+
# attrs={"endpoints": pserver_endpoints})
372+
373+
for i, ep in enumerate(eplist):
374+
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
375+
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
376+
# step4: Concat the parameters splits together after recv.
377+
for varname, splited_var in param_var_mapping.iteritems():
378+
eps = []
379+
for var in splited_var:
380+
index = [v.name for v in recv_vars].index(var.name)
381+
eps.append(eplist[index])
382+
383+
program.global_block().append_op(
384+
type="recv",
385+
inputs={},
386+
outputs={"Out": splited_var,
387+
"RPCClient": rpc_client_var},
388+
attrs={"epmap": eps})
366389

367390
program.global_block().append_op(
368391
type="fetch_barrier",
369392
inputs={},
370393
outputs={"RPCClient": rpc_client_var},
371394
attrs={"endpoints": pserver_endpoints})
372395

373-
for i, ep in enumerate(eplist):
374-
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
375-
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
376-
# step4: Concat the parameters splits together after recv.
377396
for varname, splited_var in param_var_mapping.iteritems():
378397
if len(splited_var) <= 1:
379398
continue

0 commit comments

Comments
 (0)