
Commit bfa7b3e

Merge pull request #12907 from jacquesqiao/cherry-pick-add-dependency-to-send-recv
Cherry pick add dependency to send recv
2 parents 482d297 + 86776a1 commit bfa7b3e

File tree: 10 files changed, +107 / -40 lines


paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 22 additions & 4 deletions
@@ -763,6 +763,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
 // Create RPC related op handles that connects its in ops and out ops.
 void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                                           ir::Node *node) const {
+  // FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode
+  // put them into transpiler.
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
@@ -771,26 +773,42 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                    "This hack no longer holds, please fix.");
     // the variable name which contains .block means it was splited by
     // split_byref op
-    // so that we can balance the variable blocks to all the pserver
-    // instances.
     if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
         node->inputs[0]->Name().find(".block") == std::string::npos) {
       std::vector<std::string> input_var_names;
       for (ir::Node *n : node->inputs) {
         input_var_names.push_back(n->Name());
       }
-      op_dev_id = GetAppropriateDeviceID(input_var_names);
+      auto send_param_grad = boost::get<std::vector<std::string>>(
+          node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+      PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U);
+      op_dev_id = GetAppropriateDeviceID({send_param_grad[1]});
+      VLOG(10) << "send grad " << input_var_names[0] << " origin "
+               << send_param_grad[1] << " place: " << op_dev_id;
       for (auto &varname : input_var_names) {
         result->Get<ShardedVarDevice>(kShardedVarDevice)
             .emplace(varname, op_dev_id);
       }
+      result->Get<ShardedVarDevice>(kShardedVarDevice)
+          .emplace(send_param_grad[1], op_dev_id);
     }
   } else if (node->Op()->Type() == "recv") {
     std::vector<std::string> output_var_names;
     for (ir::Node *n : node->outputs) {
       output_var_names.push_back(n->Name());
     }
-    op_dev_id = GetAppropriateDeviceID(output_var_names);
+    auto recv_param_grad = boost::get<std::vector<std::string>>(
+        node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+    // FIXME(typhoonzero): assume each recv op output one param
+    // Use the same place as send.
+    if (recv_param_grad.size() == 2U) {
+      op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
+      VLOG(10) << "recv param " << recv_param_grad[0]
+               << " get grad place: " << recv_param_grad[1]
+               << " place: " << op_dev_id;
+    } else {
+      op_dev_id = GetAppropriateDeviceID(output_var_names);
+    }
     for (auto &varname : output_var_names) {
       result->Get<ShardedVarDevice>(kShardedVarDevice)
           .emplace(varname, op_dev_id);
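
In effect, the new placement logic keys both RPC ops off the [param, grad] name pair carried in the op role var attribute: send is placed on the device chosen for the gradient it ships out, and the recv that brings the parameter back reuses that recorded device. A rough Python paraphrase of the decision (illustrative only, not a PaddlePaddle API; var_device stands in for the kShardedVarDevice map):

def place_send(var_device, pick_device, grad_name):
    # send goes to the device chosen for the gradient it ships out,
    # and that choice is recorded under the gradient's name
    dev_id = pick_device([grad_name])
    var_device[grad_name] = dev_id
    return dev_id

def place_recv(var_device, pick_device, grad_name, output_names):
    # recv reuses the device recorded for the matching gradient; if the
    # [param, grad] pair is missing it falls back to balancing its outputs
    dev_id = var_device.get(grad_name, pick_device(output_names))
    for name in output_names:
        var_device[name] = dev_id
    return dev_id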

paddle/fluid/framework/ir/node.cc

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 namespace ir {
-const char Node::kControlDepVarName[] = "__control_var";
+constexpr char Node::kControlDepVarName[];
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/framework/ir/node.h

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ namespace ir {
 class Node {
  public:
   enum class Type { kOperation, kVariable };
-  static const char kControlDepVarName[];
+  static constexpr char kControlDepVarName[] = "__control_var";

   explicit Node(const std::string& name, Type type)
       : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}

paddle/fluid/operators/recv_op.cc

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,8 @@ class RecvOp : public framework::OperatorBase {
 class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
+    AddInput("X", "(Any) Dummy inputs, used for control dependency")
+        .AsDuplicable();
     AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
     AddComment(R"DOC(
 Recv operator

paddle/fluid/operators/send_barrier_op.cc

Lines changed: 5 additions & 9 deletions
@@ -37,22 +37,19 @@ class SendBarrierOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& place) const override {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
-    bool sync_mode = Attr<bool>("sync_mode");

     distributed::RPCClient* rpc_client =
         distributed::RPCClient::GetInstance<RPCCLIENT_T>();

-    VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
+    VLOG(3) << "SendBarrierOp sync";

     // need to wait before sending send_barrier message
     PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
-    if (sync_mode) {
-      for (auto& ep : eps) {
-        VLOG(3) << "send barrier, ep: " << ep;
-        rpc_client->AsyncSendBatchBarrier(ep);
-      }
-      PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
+    for (auto& ep : eps) {
+      VLOG(3) << "send barrier, ep: " << ep;
+      rpc_client->AsyncSendBatchBarrier(ep);
     }
+    PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
   }
 };

@@ -70,7 +67,6 @@ the Parameter Server would knew all variables have been sent.
         "(string vector, default 127.0.0.1:6164)"
         "Server endpoints to send variables to.")
         .SetDefault({"127.0.0.1:6164"});
-    AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
   }
 };

paddle/fluid/operators/send_op.cc

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,8 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() {
     AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
         .AsDuplicable();
+    AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
+        .AsDuplicable();
     AddComment(R"DOC(
 Send operator

paddle/fluid/pybind/const_value.cc

Lines changed: 4 additions & 1 deletion
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/pybind/const_value.h"
-#include <paddle/fluid/framework/op_proto_maker.h>
+#include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"

 namespace paddle {
@@ -24,6 +25,8 @@ void BindConstValue(pybind11::module* m) {
   m->def("kTempVarName", [] { return framework::kTempVarName; });
   m->def("kGradVarSuffix", [] { return framework::kGradVarSuffix; });
   m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });
+  m->def("kControlDepVarName",
+         [] { return framework::ir::Node::kControlDepVarName; });

   auto op_proto_and_checker_maker =
       m->def_submodule("op_proto_and_checker_maker");

python/paddle/fluid/framework.py

Lines changed: 6 additions & 0 deletions
@@ -49,6 +49,12 @@
 TEMP_VAR_NAME = core.kTempVarName()
 GRAD_VAR_SUFFIX = core.kGradVarSuffix()
 ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
+CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
+
+
+def generate_control_dev_var_name():
+    import random
+    return CONTROL_DEP_VAR_PREFIX + "@" + str(random.random())


 def grad_var_name(var_name):
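
CONTROL_DEP_VAR_PREFIX comes from the new kControlDepVarName binding, so every control-dependency dummy variable gets a recognizable "__control_var@<random>" name that is unique within a program. A minimal sketch of creating such a variable, mirroring what the transpiler does below (the variable carries no data and exists only to order ops):

import paddle.fluid as fluid
from paddle.fluid import framework

prog = fluid.Program()
# name looks like "__control_var@0.1234...": never read, only depended on
dummy = prog.global_block().create_var(
    name=framework.generate_control_dev_var_name())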

python/paddle/fluid/layers/io.py

Lines changed: 19 additions & 4 deletions
@@ -24,7 +24,7 @@
 from .. import core
 from ..executor import global_scope
 from ..framework import convert_np_dtype_to_dtype_, default_main_program, \
-    default_startup_program, program_guard, Program
+    default_startup_program, program_guard, Program, Variable
 from ..layer_helper import LayerHelper
 from ..unique_name import generate as unique_name

@@ -209,7 +209,7 @@ def complete_op(self):
         })


-def Send(endpoints, send_vars, sync=True):
+def Send(endpoints, send_vars, dummy_output=None, sync=True):
     """
     Send variables to the server side, and get vars from server
     side when server have finished running server side program.
@@ -223,6 +223,13 @@ def Send(endpoints, send_vars, sync=True):
     """
     assert (type(send_vars) == list)

+    if dummy_output is None:
+        dummy_output = []
+    elif isinstance(dummy_output, Variable):
+        dummy_output = [dummy_output]
+
+    assert (type(dummy_output) == list)
+
     epmap = endpoints.split(",")
     endpoints = list(set(epmap))

@@ -232,6 +239,7 @@ def Send(endpoints, send_vars, sync=True):
     helper.append_op(
         type="send",
         inputs={"X": send_vars},
+        outputs={"Out": dummy_output},
         attrs={
             "endpoints": endpoints,
             "epmap": epmap,
@@ -241,7 +249,7 @@ def Send(endpoints, send_vars, sync=True):
     helper.append_op(type="send_barrier", attrs={"endpoints": endpoints})


-def Recv(endpoints, get_vars, sync=True):
+def Recv(endpoints, get_vars, dummy_input=None, sync=True):
     """
     Receive variables from server side

@@ -256,13 +264,20 @@ def Recv(endpoints, get_vars, sync=True):
     """
     assert (type(get_vars) == list)

+    if dummy_input is None:
+        dummy_input = []
+    elif isinstance(dummy_input, Variable):
+        dummy_input = [dummy_input]
+
+    assert (type(dummy_input) == list)
+
     epmap = endpoints.split(",")
     endpoints = list(set(epmap))

     helper = LayerHelper("Recv", **locals())
     helper.append_op(
         type="recv",
-        inputs={"X": get_vars},
+        inputs={"X": dummy_input},
         outputs={"Out": get_vars},
         attrs={"endpoints": endpoints,
                "epmap": epmap})

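With the new dummy_output/dummy_input arguments, a caller can thread one control variable from Send into Recv so the recv op cannot be scheduled before the matching send. A hedged usage sketch (the endpoint and variables are placeholders; this only builds the program and assumes a parameter server would be started elsewhere):

import paddle.fluid as fluid
from paddle.fluid import framework
from paddle.fluid.layers.io import Send, Recv

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    grad = fluid.layers.data(name="grad", shape=[8], dtype="float32")
    param = main.global_block().create_var(
        name="param", shape=[8], dtype="float32")

    # one variable is both the Out of send and the X of recv: a pure control edge
    dep = main.global_block().create_var(
        name=framework.generate_control_dev_var_name())

    Send("127.0.0.1:6174", [grad], dummy_output=dep)
    Recv("127.0.0.1:6174", [param], dummy_input=dep)
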
python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 45 additions & 20 deletions
@@ -210,6 +210,11 @@ def transpile(self,

         ps_dispatcher = self.config.split_method(self.pserver_endpoints)
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()
+        self.param_name_to_grad_name = dict()
+        self.grad_name_to_param_name = dict()
+        for param_var, grad_var in self.params_grads:
+            self.param_name_to_grad_name[param_var.name] = grad_var.name
+            self.grad_name_to_param_name[grad_var.name] = param_var.name

         # step 1: split and create vars, then put splited vars in dicts for later use.
         self._init_splited_vars()
@@ -229,34 +234,43 @@ def transpile(self,
             random.seed(self.origin_program.random_seed)
             random.shuffle(grad_var_mapping_items)

-        for orig_varname, splited_vars in grad_var_mapping_items:
+        grad_name_to_send_dummy_out = dict()
+        for grad_varname, splited_vars in grad_var_mapping_items:
             eplist = ps_dispatcher.dispatch(splited_vars)

             if not self.config.slice_var_up:
                 assert (len(splited_vars) == 1)

+            splited_grad_varname = grad_varname
             if len(splited_vars) == 1:
-                orig_varname = splited_vars[0].name
+                splited_grad_varname = splited_vars[0].name
                 index = find_op_by_output_arg(program.global_block(),
-                                              orig_varname)
+                                              splited_grad_varname)
             elif len(splited_vars) > 1:
-                orig_var = program.global_block().vars[orig_varname]
+                orig_var = program.global_block().vars[splited_grad_varname]
                 index = find_op_by_output_arg(program.global_block(),
-                                              orig_varname)
+                                              splited_grad_varname)
                 self._insert_split_op(program, orig_var, index, splited_vars)
                 index += 1
             else:
                 AssertionError("Can not insert the send op by original "
-                               "variable name :", orig_varname)
+                               "variable name :", splited_grad_varname)
+
+            dummy_output = program.global_block().create_var(
+                name=framework.generate_control_dev_var_name())
+            grad_name_to_send_dummy_out[grad_varname] = dummy_output

             program.global_block()._insert_op(
                 index=index + 1,
                 type="send",
                 inputs={"X": splited_vars},
-                outputs={},
+                outputs={"Out": dummy_output},
                 attrs={
                     "epmap": eplist,
-                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    OP_ROLE_VAR_ATTR_NAME:
+                    [self.grad_name_to_param_name[grad_varname], grad_varname],
+                    "sync_mode": not self.sync_mode,
                 })
             for _, var in enumerate(splited_vars):
                 send_vars.append(var)
@@ -268,7 +282,6 @@ def transpile(self,
                 outputs={},
                 attrs={
                     "endpoints": pserver_endpoints,
-                    "sync_mode": self.sync_mode,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })

@@ -284,19 +297,25 @@ def transpile(self,
             self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])

         # step4: Concat the parameters splits together after recv.
-        for varname, splited_var in six.iteritems(self.param_var_mapping):
+        for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             eps = []
             for var in splited_var:
                 index = [v.name for v in recv_vars].index(var.name)
                 eps.append(eplist[index])
-
+            grad_send_dummy_out = grad_name_to_send_dummy_out[
+                self.param_name_to_grad_name[param_varname]]
             program.global_block().append_op(
                 type="recv",
-                inputs={},
+                inputs={"X": [grad_send_dummy_out]},
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
-                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    OP_ROLE_VAR_ATTR_NAME: [
+                        param_varname,
+                        self.param_name_to_grad_name[param_varname]
+                    ],
+                    "sync_mode": not self.sync_mode
                 })

         if self.sync_mode:
@@ -309,10 +328,10 @@ def transpile(self,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })

-        for varname, splited_var in six.iteritems(self.param_var_mapping):
+        for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             if len(splited_var) <= 1:
                 continue
-            orig_param = program.global_block().vars[varname]
+            orig_param = program.global_block().vars[param_varname]
             program.global_block().append_op(
                 type="concat",
                 inputs={"X": splited_var},
@@ -380,7 +399,7 @@ def _get_trainer_startup_program(self,

             op = startup_program.global_block().append_op(
                 type="recv",
-                inputs={},
+                inputs={"X": []},
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
@@ -786,19 +805,21 @@ def _init_splited_vars(self):
                                                    self.config.min_block_size)
         assert (len(grad_blocks) == len(param_blocks))

-        # origin_varname -> [splited_var]
+        # origin_param_name -> [splited_param_vars]
         self.param_var_mapping = self._create_vars_from_blocklist(
             self.origin_program, param_blocks)
+        # origin_grad_name -> [splited_grad_vars]
         self.grad_var_mapping = self._create_vars_from_blocklist(
             self.origin_program,
             grad_blocks,
             add_trainer_suffix=self.trainer_num > 1)
+        # dict(grad_splited_var -> param_splited_var)
         self.grad_param_mapping = collections.OrderedDict()
         for g, p in zip(grad_blocks, param_blocks):
             g_name, g_bid, _ = g.split(":")
             p_name, p_bid, _ = p.split(":")
             self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
-                self.param_var_mapping[p_name][int(p_bid)]
+                 self.param_var_mapping[p_name][int(p_bid)]

         # create mapping of endpoint -> split var to create pserver side program
         self.param_grad_ep_mapping = collections.OrderedDict()
@@ -919,11 +940,15 @@ def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
                 index=op_index + 2,
                 type="send",
                 inputs={'X': self.trainer_side_table_grad_list},
-                outputs={},
+                outputs={'Out': []},
                 attrs={
                     "sync_mode": True,
                     "epmap": pserver_endpoints,
-                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    OP_ROLE_VAR_ATTR_NAME: [
+                        self.grad_name_to_param_name[table_grad_name],
+                        table_grad_name
+                    ]
                 })
             break
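
Taken together, the transpiled trainer program now carries an explicit dependency chain per parameter: the send of a gradient writes a dummy control variable, and the recv of the matching parameter reads it, with the [param, grad] pair attached so the multi-device graph pass can co-locate both ops. A simplified paraphrase of that wiring (not the transpiler code itself; the argument names and mappings are illustrative):

def wire_send_recv(block, grad_slices, param_slices, grad_to_param,
                   eplist, make_ctrl_name):
    """grad_slices/param_slices map names to lists of splited vars."""
    send_dummy = {}  # grad name -> control-dependency variable
    for grad_name, slices in grad_slices.items():
        dummy = block.create_var(name=make_ctrl_name())
        send_dummy[grad_name] = dummy
        block.append_op(
            type="send",
            inputs={"X": slices},
            outputs={"Out": [dummy]},  # marks "this gradient has been sent"
            attrs={"epmap": eplist})

    param_to_grad = {p: g for g, p in grad_to_param.items()}
    for param_name, slices in param_slices.items():
        dep = send_dummy[param_to_grad[param_name]]
        block.append_op(
            type="recv",
            inputs={"X": [dep]},  # recv now depends on the matching send
            outputs={"Out": slices},
            attrs={"epmap": eplist})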
