
Commit 0ee6fed

Refine dist rpc deps (#12899)
* refine dist train RPC deps
* clean up
* fix ut
* remove input for fetch_barrier
* follow comments
1 parent 3a0b6f9 commit 0ee6fed

File tree

10 files changed: +101 -106 lines changed

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 41 additions & 24 deletions
@@ -754,17 +754,26 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
                  node->Op()->Type());
 
   CreateComputationalOp(result, node, op_dev_id);
-  if (node->Op()->Type() == "concat") {
-    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
-              "fetch_barrier");
+}
+
+void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+  for (ir::Node *input : node->inputs) {
+    VarHandle *var = nullptr;
+    for (int place_offset = 0; place_offset < num_places; ++place_offset) {
+      auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
+      auto &var_holder = var_holders[input->Name()];
+      if (!var_holder.empty()) {
+        var = var_holder.rbegin()->get();
+        op_handle->AddInput(var);
+      }
+    }
   }
 }
 
 // Create RPC related op handles that connects its in ops and out ops.
 void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                                           ir::Node *node) const {
-  // FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode
-  // put them into transpiler.
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
@@ -799,8 +808,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
   }
   auto recv_param_grad = boost::get<std::vector<std::string>>(
       node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-  // FIXME(typhoonzero): assume each recv op output one param
-  // Use the same place as send.
   if (recv_param_grad.size() == 2U) {
     op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
     VLOG(10) << "recv param " << recv_param_grad[0]
@@ -814,34 +821,44 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
           .emplace(varname, op_dev_id);
     }
   } else {
-    // send_barrier and fetch_barrier op can be scheduled on device 0
+    // send_barrier, fetch_barrier will run on place 0;
     op_dev_id = 0;
   }
 
   PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                  node->Op()->Type());
-
   result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
       result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
       node->Op()->Type(), places_[op_dev_id]));
 
-  // TODO(panyx0718): This might not be needed anymore.
-  if (node->Op()->Type() == "send_barrier") {
-    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "send");
-  } else if (node->Op()->Type() == "recv") {
-    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
-              "send_barrier");
-  } else if (node->Op()->Type() == "fetch_barrier") {
-    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "recv");
-  } else if (node->Op()->Type() == "send") {
-    // do nothing
+  if (node->Op()->Type() == "send") {
+    CreateOpHandleIOs(result, node, op_dev_id);
   } else {
-    PADDLE_THROW(
-        "rpc op should be in ["
-        "send, send_barrier. recv, fetch_barrier]");
-  }
+    // send_barrier, recv, fetch_barrier's inputs are deps var, get them from
+    // all places
+    auto p = places_[op_dev_id];
+    auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+    op_handle->SetDeviceContext(p,
+                                platform::DeviceContextPool::Instance().Get(p));
 
-  CreateOpHandleIOs(result, node, op_dev_id);
+    SetOpInputsAllPlaces(result, node, places_.size());
+    for (ir::Node *output : node->outputs) {
+      int outvar_dev_id = op_dev_id;
+      if (node->Op()->Type() == "fetch_barrier") {
+        outvar_dev_id = GetVarDeviceID(*result, output->Name());
+        PADDLE_ENFORCE_NE(outvar_dev_id, -1);
+      }
+      p = places_[outvar_dev_id];
+      ir::Node *new_node = nullptr;
+      if (output->Var()) {
+        new_node = result->CreateVarNode(output->Var());
+      } else {
+        new_node =
+            result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
+      }
+      CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id);
    }
+  }
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
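
Note on the hunk above: RPC ordering is no longer wired with hard-coded ConnectOp edges. Each barrier op now reads and writes dummy dependency variables, and the graph builder recovers the execution order from ordinary read-after-write and write-after-write dependencies. A minimal toy sketch of that idea in plain Python (illustrative names only, no Paddle dependency):

ops = {
    "send":          {"in": [],              "out": ["send_dep"]},
    "send_barrier":  {"in": ["send_dep"],    "out": ["barrier_dep"]},
    "recv":          {"in": ["barrier_dep"], "out": ["param"]},
    "fetch_barrier": {"in": [],              "out": ["param"]},  # WAW on param
}

edges, writers = set(), {}
for name, io in ops.items():               # insertion order = program order
    for v in io["in"]:
        edges.add((writers[v], name))      # read-after-write edge
    for v in io["out"]:
        if v in writers:
            edges.add((writers[v], name))  # write-after-write edge
        writers[v] = name

indeg = {n: sum(b == n for _, b in edges) for n in ops}
order = []
while indeg:                               # Kahn's topological sort
    n = next(k for k, d in indeg.items() if d == 0)
    del indeg[n]
    order.append(n)
    for a, b in edges:
        if a == n:
            indeg[b] -= 1
print(order)  # ['send', 'send_barrier', 'recv', 'fetch_barrier']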

paddle/fluid/framework/ir/graph.cc

Lines changed: 0 additions & 57 deletions
@@ -132,63 +132,6 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
     }
   }
 
-  std::vector<ir::Node *> send_ops;
-  ir::Node *send_bar = nullptr;
-  std::vector<ir::Node *> recv_ops;
-  ir::Node *fetch_bar = nullptr;
-  for (ir::Node *node : Nodes()) {
-    if (node->Name() == "send") {
-      send_ops.push_back(node);
-    } else if (node->Name() == "send_barrier") {
-      PADDLE_ENFORCE(!send_bar, "only has one send barrier");
-      send_bar = node;
-    } else if (node->Name() == "recv") {
-      recv_ops.push_back(node);
-    } else if (node->Name() == "fetch_barrier") {
-      PADDLE_ENFORCE(!fetch_bar, "only has one fetch barrier");
-      fetch_bar = node;
-    }
-  }
-  if (send_bar) {
-    for (ir::Node *send : send_ops) {
-      ir::Node *dep_var = CreateControlDepVar();
-      send->outputs.push_back(dep_var);
-      dep_var->inputs.push_back(send);
-      send_bar->inputs.push_back(dep_var);
-      dep_var->outputs.push_back(send_bar);
-    }
-    for (ir::Node *recv : recv_ops) {
-      ir::Node *dep_var = CreateControlDepVar();
-      recv->inputs.push_back(dep_var);
-      dep_var->outputs.push_back(recv);
-      send_bar->outputs.push_back(dep_var);
-      dep_var->inputs.push_back(send_bar);
-    }
-  }
-  if (fetch_bar) {
-    for (ir::Node *recv : recv_ops) {
-      ir::Node *dep_var = CreateControlDepVar();
-      recv->outputs.push_back(dep_var);
-      dep_var->inputs.push_back(recv);
-      fetch_bar->inputs.push_back(dep_var);
-      dep_var->outputs.push_back(fetch_bar);
-    }
-  }
-
-  std::vector<std::string> send_vars = FindDistTrainSendVars(send_ops);
-  std::vector<std::string> recv_vars = FindDistTrainRecvVars(recv_ops);
-  for (ir::Node *node : Nodes()) {
-    if (IsDistTrainOp(node, send_vars, recv_vars)) {
-      if (fetch_bar && node->Name() == "concat") {
-        ir::Node *dep_var = CreateControlDepVar();
-        fetch_bar->outputs.push_back(dep_var);
-        dep_var->inputs.push_back(fetch_bar);
-        node->inputs.push_back(dep_var);
-        dep_var->outputs.push_back(node);
-      }
-    }
-  }
-
   /**
    * We should handle write after read(WAR) and write after write(WAW) here.
    * Because some of the operators of the program can be executed parallelly.
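
Note on the deletion above: the Graph constructor no longer special-cases distributed training ops; the send/send_barrier/recv/fetch_barrier ordering is now expressed by the transpiler through the ops' own dummy inputs and outputs (see the operator and transpiler changes below). For reference, the pattern the deleted code implemented by hand, as a hypothetical Python sketch (Node and add_control_dep are illustrative stand-ins mirroring CreateControlDepVar):

class Node:
    def __init__(self, name):
        self.name, self.inputs, self.outputs = name, [], []

_counter = [0]

def add_control_dep(producer, consumer):
    # An empty variable node whose only purpose is the edge
    # producer -> dep_var -> consumer, like CreateControlDepVar().
    dep = Node("ctrl_dep@%d" % _counter[0])
    _counter[0] += 1
    producer.outputs.append(dep)
    dep.inputs.append(producer)
    consumer.inputs.append(dep)
    dep.outputs.append(consumer)
    return dep

send, bar = Node("send"), Node("send_barrier")
add_control_dep(send, bar)   # what the deleted loop did for every send op
assert bar.inputs[0].inputs[0] is send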

paddle/fluid/operators/fetch_barrier_op.cc

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,8 @@ class FetchBarrierOp : public framework::OperatorBase {
 class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
+    AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
+        .AsDuplicable();
     AddComment(R"DOC(
 SendBarrier operator

paddle/fluid/operators/send_barrier_op.cc

Lines changed: 4 additions & 0 deletions
@@ -56,6 +56,10 @@ class SendBarrierOp : public framework::OperatorBase {
 class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
+    AddInput("X", "(Any) Dummy inputs, used for control dependency")
+        .AsDuplicable();
+    AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
+        .AsDuplicable();
     AddComment(R"DOC(
 SendBarrier operator
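
With the X/Out slots added to both barrier ops, a program can state the dependency explicitly when appending them. A sketch of the intended wiring (block, send_dummies, and pserver_endpoints are assumed to exist in the caller's scope; the calls mirror the transpiler hunk further below):

barrier_out = block.create_var(
    name=framework.generate_control_dev_var_name())
block.append_op(
    type="send_barrier",
    inputs={"X": send_dummies},    # dummy outputs of the preceding send ops
    outputs={"Out": barrier_out},  # dummy var that later recv ops consume
    attrs={"endpoints": pserver_endpoints})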

python/paddle/fluid/layers/io.py

Lines changed: 9 additions & 2 deletions
@@ -246,7 +246,11 @@ def Send(endpoints, send_vars, dummy_output=None, sync=True):
             rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
         })
     if sync:
-        helper.append_op(type="send_barrier", attrs={"endpoints": endpoints})
+        helper.append_op(
+            type="send_barrier",
+            inputs={"X": dummy_output},
+            outputs={"Out": []},
+            attrs={"endpoints": endpoints})
 
 
 def Recv(endpoints, get_vars, dummy_input=None, sync=True):
@@ -282,7 +286,10 @@ def Recv(endpoints, get_vars, dummy_input=None, sync=True):
         attrs={"endpoints": endpoints,
                "epmap": epmap})
     if sync:
-        helper.append_op(type="fetch_barrier", attrs={"endpoints": endpoints})
+        helper.append_op(
+            type="fetch_barrier",
+            outputs={"Out": get_vars},
+            attrs={"endpoints": endpoints})
     return get_vars
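
A hedged usage sketch of the layer-level API after this change (grad_var, param_var, dummy_var, and the endpoint string are illustrative placeholders):

import paddle.fluid as fluid

endpoints = "127.0.0.1:6174"
fluid.layers.Send(endpoints, [grad_var], dummy_output=dummy_var, sync=True)
# sync=True now appends send_barrier with X=dummy_output, so the barrier
# cannot be scheduled before the sends complete.
recved = fluid.layers.Recv(endpoints, [param_var], sync=True)
# sync=True appends fetch_barrier with Out=get_vars, a write-after-write
# dependency that keeps consumers of the received vars after the barrier.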

python/paddle/fluid/tests/unittests/dist_se_resnext.py

Lines changed: 14 additions & 8 deletions
@@ -134,7 +134,7 @@ def net(self, input, class_dim=1000):
             size=class_dim,
             act='softmax',
             param_attr=fluid.ParamAttr(
-                initializer=fluid.initializer.Constant(value=0.2)))
+                initializer=fluid.initializer.Constant(value=0.05)))
         return out
 
     def shortcut(self, input, ch_out, stride):
@@ -184,21 +184,27 @@ def conv_bn_layer(self,
             act=None,
             # avoid pserver CPU init differs from GPU
             param_attr=fluid.ParamAttr(
-                initializer=fluid.initializer.Constant(value=0.2)),
+                initializer=fluid.initializer.Constant(value=0.05)),
             bias_attr=False)
         return fluid.layers.batch_norm(input=conv, act=act)
 
     def squeeze_excitation(self, input, num_channels, reduction_ratio):
         pool = fluid.layers.pool2d(
             input=input, pool_size=0, pool_type='avg', global_pooling=True)
         stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
-        squeeze = fluid.layers.fc(input=pool,
-                                  size=num_channels // reduction_ratio,
-                                  act='relu')
+        squeeze = fluid.layers.fc(
+            input=pool,
+            size=num_channels // reduction_ratio,
+            param_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=0.05)),
+            act='relu')
         stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
-        excitation = fluid.layers.fc(input=squeeze,
-                                     size=num_channels,
-                                     act='sigmoid')
+        excitation = fluid.layers.fc(
+            input=squeeze,
+            size=num_channels,
+            param_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=0.05)),
+            act='sigmoid')
         scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
         return scale

python/paddle/fluid/tests/unittests/dist_word2vec.py

Lines changed: 10 additions & 6 deletions
@@ -49,28 +49,32 @@ def __network__(words):
             dtype='float32',
             is_sparse=IS_SPARSE,
             param_attr=fluid.ParamAttr(
-                name='shared_w', initializer=fluid.initializer.Constant()))
+                name='shared_w',
+                initializer=fluid.initializer.Constant(value=0.1)))
         embed_second = fluid.layers.embedding(
             input=words[1],
             size=[dict_size, EMBED_SIZE],
             dtype='float32',
             is_sparse=IS_SPARSE,
             param_attr=fluid.ParamAttr(
-                name='shared_w', initializer=fluid.initializer.Constant()))
+                name='shared_w',
+                initializer=fluid.initializer.Constant(value=0.1)))
         embed_third = fluid.layers.embedding(
             input=words[2],
             size=[dict_size, EMBED_SIZE],
             dtype='float32',
             is_sparse=IS_SPARSE,
             param_attr=fluid.ParamAttr(
-                name='shared_w', initializer=fluid.initializer.Constant()))
+                name='shared_w',
+                initializer=fluid.initializer.Constant(value=0.1)))
         embed_forth = fluid.layers.embedding(
             input=words[3],
             size=[dict_size, EMBED_SIZE],
             dtype='float32',
             is_sparse=IS_SPARSE,
             param_attr=fluid.ParamAttr(
-                name='shared_w', initializer=fluid.initializer.Constant()))
+                name='shared_w',
+                initializer=fluid.initializer.Constant(value=0.1)))
 
         concat_embed = fluid.layers.concat(
             input=[embed_first, embed_second, embed_third, embed_forth],
@@ -80,13 +84,13 @@ def __network__(words):
             size=HIDDEN_SIZE,
             act='sigmoid',
             param_attr=fluid.ParamAttr(
-                initializer=fluid.initializer.Constant()))
+                initializer=fluid.initializer.Constant(value=0.1)))
         predict_word = fluid.layers.fc(
             input=hidden1,
             size=dict_size,
             act='softmax',
             param_attr=fluid.ParamAttr(
-                initializer=fluid.initializer.Constant()))
+                initializer=fluid.initializer.Constant(value=0.1)))
         cost = fluid.layers.cross_entropy(
             input=predict_word, label=words[4])
         avg_cost = fluid.layers.mean(cost)
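
The shared pattern in the two test nets above, as a sketch (fc_const is a hypothetical helper, not part of the test files): pin every parameter to a constant initializer so trainer and pserver replicas start from identical weights, which keeps local and distributed losses comparable within the test's delta.

import paddle.fluid as fluid

def fc_const(input, size, act=None, value=0.1):
    # Constant init avoids random, device-dependent initialization.
    return fluid.layers.fc(
        input=input,
        size=size,
        act=act,
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=value)))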

python/paddle/fluid/tests/unittests/test_dist_train.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def init_client(self, place, port):
         main.global_block().append_op(
             type="fetch_barrier",
             inputs={},
-            outputs={},
+            outputs={"Out": []},
             attrs={
                 "endpoints": ["127.0.0.1:{0}".format(port)],
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE

python/paddle/fluid/tests/unittests/test_dist_word2vec.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def _setup_config(self):
         self._sync_mode = True
 
     def test_se_resnext(self):
-        self.check_with_place("dist_word2vec.py", delta=1e-7)
+        self.check_with_place("dist_word2vec.py", delta=1e-4)
 
 
 class TestDistSeResneXt2x2Async(TestDistBase):

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 19 additions & 7 deletions
@@ -283,10 +283,13 @@ def transpile(self,
                 send_vars.append(var)
 
         if self.sync_mode:
+            send_barrier_out = program.global_block().create_var(
+                name=framework.generate_control_dev_var_name())
+            input_deps = grad_name_to_send_dummy_out.values()
             program.global_block().append_op(
                 type="send_barrier",
-                inputs={},
-                outputs={},
+                inputs={"X": input_deps},
+                outputs={"Out": send_barrier_out},
                 attrs={
                     "endpoints": pserver_endpoints,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
@@ -304,16 +307,22 @@ def transpile(self,
                 self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
 
         # step4: Concat the parameters splits together after recv.
+        all_recv_outputs = []
         for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             eps = []
             for var in splited_var:
                 index = [v.name for v in recv_vars].index(var.name)
                 eps.append(eplist[index])
-            grad_send_dummy_out = grad_name_to_send_dummy_out[
-                self.param_name_to_grad_name[param_varname]]
+            if self.sync_mode:
+                recv_dep_in = send_barrier_out
+            else:
+                # connect deps to send op in async mode
+                recv_dep_in = grad_name_to_send_dummy_out[
+                    self.param_name_to_grad_name[param_varname]]
+            all_recv_outputs.extend(splited_var)
             program.global_block().append_op(
                 type="recv",
-                inputs={"X": [grad_send_dummy_out]},
+                inputs={"X": [recv_dep_in]},
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
@@ -326,10 +335,11 @@ def transpile(self,
                 })
 
         if self.sync_mode:
+            # form a WAW dependency
             program.global_block().append_op(
                 type="fetch_barrier",
                 inputs={},
-                outputs={},
+                outputs={"Out": all_recv_outputs},
                 attrs={
                     "endpoints": pserver_endpoints,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
@@ -413,10 +423,12 @@ def _get_trainer_startup_program(self,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })
 
+        fetch_barrier_out = startup_program.global_block().create_var(
+            name=framework.generate_control_dev_var_name())
         startup_program.global_block().append_op(
             type="fetch_barrier",
             inputs={},
-            outputs={},
+            outputs={"Out": fetch_barrier_out},
             attrs={
                 "endpoints": self.pserver_endpoints,
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
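
Taken together, a condensed sketch of the sync-mode dependency chain the transpiler now emits (block, eps, splited_var, and the other names are stand-ins for the variables in the hunks above):

send_dummies = list(grad_name_to_send_dummy_out.values())
send_barrier_out = block.create_var(
    name=framework.generate_control_dev_var_name())
block.append_op(type="send_barrier",
                inputs={"X": send_dummies},
                outputs={"Out": send_barrier_out},
                attrs={"endpoints": pserver_endpoints})
block.append_op(type="recv",
                inputs={"X": [send_barrier_out]},
                outputs={"Out": splited_var},
                attrs={"epmap": eps})
block.append_op(type="fetch_barrier",
                inputs={},
                outputs={"Out": all_recv_outputs},
                attrs={"endpoints": pserver_endpoints})
# Resulting chain: send -> send_barrier (RAW) -> recv (RAW) -> fetch_barrier
# (WAW on the recv outputs), with no hand-wired graph edges needed downstream.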
