Commit 9a43c9a

Add dependency to send recv (#12760)
1 parent 482d297 commit 9a43c9a
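
This change gives the distributed send/recv pair an explicit dependency edge. The recv op gains a duplicable dummy input "X" and the send op a duplicable dummy output "Out"; the Python Send/Recv layers accept matching dummy_output/dummy_input arguments; and the distribute transpiler creates one dummy variable per gradient, routing it from each send op's "Out" into the matching recv op's "X", so a recv can only run after its corresponding send. The send_barrier op drops its sync_mode attribute and now always broadcasts the batch barrier.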

5 files changed: +58 -32 lines changed

paddle/fluid/operators/recv_op.cc

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,8 @@ class RecvOp : public framework::OperatorBase {
 class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
+    AddInput("X", "(Any) Dummy inputs, used for control dependency")
+        .AsDuplicable();
     AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
     AddComment(R"DOC(
 Recv operator

paddle/fluid/operators/send_barrier_op.cc

Lines changed: 5 additions & 9 deletions
@@ -37,22 +37,19 @@ class SendBarrierOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& place) const override {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
-    bool sync_mode = Attr<bool>("sync_mode");
 
     distributed::RPCClient* rpc_client =
         distributed::RPCClient::GetInstance<RPCCLIENT_T>();
 
-    VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
+    VLOG(3) << "SendBarrierOp sync";
 
     // need to wait before sending send_barrier message
     PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
-    if (sync_mode) {
-      for (auto& ep : eps) {
-        VLOG(3) << "send barrier, ep: " << ep;
-        rpc_client->AsyncSendBatchBarrier(ep);
-      }
-      PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
+    for (auto& ep : eps) {
+      VLOG(3) << "send barrier, ep: " << ep;
+      rpc_client->AsyncSendBatchBarrier(ep);
     }
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
   }
 };

@@ -70,7 +67,6 @@ the Parameter Server would knew all variables have been sent.
                                       "(string vector, default 127.0.0.1:6164)"
                                       "Server endpoints to send variables to.")
         .SetDefault({"127.0.0.1:6164"});
-    AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
   }
 };

paddle/fluid/operators/send_op.cc

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,8 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() {
     AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
         .AsDuplicable();
+    AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
+        .AsDuplicable();
     AddComment(R"DOC(
 Send operator

python/paddle/fluid/layers/io.py

Lines changed: 19 additions & 4 deletions
@@ -24,7 +24,7 @@
 from .. import core
 from ..executor import global_scope
 from ..framework import convert_np_dtype_to_dtype_, default_main_program, \
-    default_startup_program, program_guard, Program
+    default_startup_program, program_guard, Program, Variable
 from ..layer_helper import LayerHelper
 from ..unique_name import generate as unique_name

@@ -209,7 +209,7 @@ def complete_op(self):
         })
 
 
-def Send(endpoints, send_vars, sync=True):
+def Send(endpoints, send_vars, dummy_output=None, sync=True):
     """
     Send variables to the server side, and get vars from server
     side when server have finished running server side program.
@@ -223,6 +223,13 @@ def Send(endpoints, send_vars, sync=True):
     """
     assert (type(send_vars) == list)
 
+    if dummy_output is None:
+        dummy_output = []
+    elif isinstance(dummy_output, Variable):
+        dummy_output = [dummy_output]
+
+    assert (type(dummy_output) == list)
+
     epmap = endpoints.split(",")
     endpoints = list(set(epmap))

@@ -232,6 +239,7 @@ def Send(endpoints, send_vars, sync=True):
     helper.append_op(
         type="send",
         inputs={"X": send_vars},
+        outputs={"Out": dummy_output},
         attrs={
             "endpoints": endpoints,
             "epmap": epmap,
@@ -241,7 +249,7 @@ def Send(endpoints, send_vars, sync=True):
         helper.append_op(type="send_barrier", attrs={"endpoints": endpoints})
 
 
-def Recv(endpoints, get_vars, sync=True):
+def Recv(endpoints, get_vars, dummy_input=None, sync=True):
     """
     Receive variables from server side
@@ -256,13 +264,20 @@ def Recv(endpoints, get_vars, sync=True):
     """
     assert (type(get_vars) == list)
 
+    if dummy_input is None:
+        dummy_input = []
+    elif isinstance(dummy_input, Variable):
+        dummy_input = [dummy_input]
+
+    assert (type(dummy_input) == list)
+
     epmap = endpoints.split(",")
     endpoints = list(set(epmap))
 
     helper = LayerHelper("Recv", **locals())
     helper.append_op(
         type="recv",
-        inputs={"X": get_vars},
+        inputs={"X": dummy_input},
         outputs={"Out": get_vars},
         attrs={"endpoints": endpoints,
                "epmap": epmap})

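A minimal usage sketch of the updated Python API follows, assuming a Paddle build with the distributed operators available and a single illustrative endpoint; the dummy variable name and the toy program are not part of this commit, and the snippet only constructs the program (a running parameter server would be needed to execute it).

import paddle.fluid as fluid
from paddle.fluid.layers.io import Send, Recv

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    # Variable pushed to the parameter server and fetched back afterwards.
    x = fluid.layers.data(name="x", shape=[8], dtype="float32")

    # The dummy variable carries no data; it only creates a dependency edge
    # from the send op to the recv op (the name is illustrative).
    dep = main.global_block().create_var(name="send_recv_dep")

    # send writes "dep" through its dummy Out; recv reads it through its
    # dummy X, so the executor must order the recv after the send.
    Send("127.0.0.1:6164", [x], dummy_output=dep)
    Recv("127.0.0.1:6164", [x], dummy_input=dep)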
python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 30 additions & 19 deletions
@@ -210,6 +210,9 @@ def transpile(self,
 
         ps_dispatcher = self.config.split_method(self.pserver_endpoints)
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()
+        self.param_name_to_grad_name = dict()
+        for param_var, grad_var in self.params_grads:
+            self.param_name_to_grad_name[param_var.name] = grad_var.name
 
         # step 1: split and create vars, then put splited vars in dicts for later use.
         self._init_splited_vars()
@@ -229,34 +232,39 @@ def transpile(self,
             random.seed(self.origin_program.random_seed)
             random.shuffle(grad_var_mapping_items)
 
-        for orig_varname, splited_vars in grad_var_mapping_items:
+        grad_name_to_send_dummy_out = dict()
+        for grad_varname, splited_vars in grad_var_mapping_items:
             eplist = ps_dispatcher.dispatch(splited_vars)
 
             if not self.config.slice_var_up:
                 assert (len(splited_vars) == 1)
 
+            splited_grad_varname = grad_varname
             if len(splited_vars) == 1:
-                orig_varname = splited_vars[0].name
+                splited_grad_varname = splited_vars[0].name
                 index = find_op_by_output_arg(program.global_block(),
-                                              orig_varname)
+                                              splited_grad_varname)
             elif len(splited_vars) > 1:
-                orig_var = program.global_block().vars[orig_varname]
+                orig_var = program.global_block().vars[splited_grad_varname]
                 index = find_op_by_output_arg(program.global_block(),
-                                              orig_varname)
+                                              splited_grad_varname)
                 self._insert_split_op(program, orig_var, index, splited_vars)
                 index += 1
             else:
                 AssertionError("Can not insert the send op by original "
-                               "variable name :", orig_varname)
+                               "variable name :", splited_grad_varname)
 
+            dummy_output = program.global_block().create_var()
+            grad_name_to_send_dummy_out[grad_varname] = dummy_output
             program.global_block()._insert_op(
                 index=index + 1,
                 type="send",
                 inputs={"X": splited_vars},
-                outputs={},
+                outputs={"Out": dummy_output},
                 attrs={
                     "epmap": eplist,
-                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    "sync_mode": not self.sync_mode,
                 })
             for _, var in enumerate(splited_vars):
                 send_vars.append(var)
@@ -268,7 +276,6 @@ def transpile(self,
                 outputs={},
                 attrs={
                     "endpoints": pserver_endpoints,
-                    "sync_mode": self.sync_mode,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })

@@ -284,19 +291,21 @@ def transpile(self,
             self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
 
         # step4: Concat the parameters splits together after recv.
-        for varname, splited_var in six.iteritems(self.param_var_mapping):
+        for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             eps = []
             for var in splited_var:
                 index = [v.name for v in recv_vars].index(var.name)
                 eps.append(eplist[index])
-
+            grad_send_dummy_out = grad_name_to_send_dummy_out[
+                self.param_name_to_grad_name[param_varname]]
             program.global_block().append_op(
                 type="recv",
-                inputs={},
+                inputs={"X": [grad_send_dummy_out]},
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
-                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    "sync_mode": not self.sync_mode
                 })
 
         if self.sync_mode:
@@ -309,10 +318,10 @@ def transpile(self,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })
 
-        for varname, splited_var in six.iteritems(self.param_var_mapping):
+        for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             if len(splited_var) <= 1:
                 continue
-            orig_param = program.global_block().vars[varname]
+            orig_param = program.global_block().vars[param_varname]
             program.global_block().append_op(
                 type="concat",
                 inputs={"X": splited_var},
@@ -380,7 +389,7 @@ def _get_trainer_startup_program(self,
 
             op = startup_program.global_block().append_op(
                 type="recv",
-                inputs={},
+                inputs={"X": []},
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
@@ -786,19 +795,21 @@ def _init_splited_vars(self):
                                          self.config.min_block_size)
         assert (len(grad_blocks) == len(param_blocks))
 
-        # origin_varname -> [splited_var]
+        # origin_param_name -> [splited_param_vars]
         self.param_var_mapping = self._create_vars_from_blocklist(
             self.origin_program, param_blocks)
+        # origin_grad_name -> [splited_grad_vars]
         self.grad_var_mapping = self._create_vars_from_blocklist(
             self.origin_program,
             grad_blocks,
             add_trainer_suffix=self.trainer_num > 1)
+        # dict(grad_splited_var -> param_splited_var)
         self.grad_param_mapping = collections.OrderedDict()
         for g, p in zip(grad_blocks, param_blocks):
             g_name, g_bid, _ = g.split(":")
             p_name, p_bid, _ = p.split(":")
             self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
-                        self.param_var_mapping[p_name][int(p_bid)]
+                self.param_var_mapping[p_name][int(p_bid)]
 
         # create mapping of endpoint -> split var to create pserver side program
         self.param_grad_ep_mapping = collections.OrderedDict()
@@ -919,7 +930,7 @@ def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
             index=op_index + 2,
             type="send",
             inputs={'X': self.trainer_side_table_grad_list},
-            outputs={},
+            outputs={'Out': []},
             attrs={
                 "sync_mode": True,
                 "epmap": pserver_endpoints,

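For reference, the dependency wiring the transpiler now performs can be condensed as below. The function name and parameters are hypothetical stand-ins for the transpiler state shown above (grad_var_mapping, param_var_mapping, param_name_to_grad_name and the endpoint lists); op role attributes, insertion indices and the sync_mode flag are omitted for brevity.

def add_send_recv_dependencies(program, grad_var_mapping, param_var_mapping,
                               param_name_to_grad_name, eplist, eps):
    """Condensed sketch: link every recv op to the send op of its gradient."""
    block = program.global_block()

    grad_name_to_send_dummy_out = dict()
    for grad_name, splited_grads in grad_var_mapping.items():
        # One anonymous dummy variable per gradient; the send op produces it.
        dummy = block.create_var()
        grad_name_to_send_dummy_out[grad_name] = dummy
        block.append_op(
            type="send",
            inputs={"X": splited_grads},
            outputs={"Out": dummy},
            attrs={"epmap": eplist})

    for param_name, splited_params in param_var_mapping.items():
        # The recv op consumes the dummy written by the send op of the
        # matching gradient, so each recv is ordered after its send.
        dep = grad_name_to_send_dummy_out[param_name_to_grad_name[param_name]]
        block.append_op(
            type="recv",
            inputs={"X": [dep]},
            outputs={"Out": splited_params},
            attrs={"epmap": eps})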