
Commit 86776a1

typhoonzero authored and jacquesqiao committed

Resolve multi gpu async deps (#12828)

* dist transpiler add control dependency var between send and recv
* fix async deps
* follow comments and refine
* fix deps connect for rpc ops

1 parent 9a43c9a commit 86776a1
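The heart of this change: in async mode nothing forced a recv to wait for its matching send, so the transpiler now creates a uniquely named dummy "control dependency" variable as each send op's output, giving schedulers an ordinary data edge to honor. A toy sketch of the idea in Python (the op dicts and names here are illustrative, not Paddle's actual IR):

send_op = {"type": "send",
           "inputs": ["w@GRAD"],
           "outputs": ["__control_var@0.9527"]}   # dummy control-dep output
recv_op = {"type": "recv",
           "inputs": ["__control_var@0.9527"],    # consumes send's dummy var
           "outputs": ["w"]}

def depends_on(later, earlier):
    # A scheduler that honors data dependencies must run `earlier`
    # before `later` iff they share a variable edge.
    return bool(set(later["inputs"]) & set(earlier["outputs"]))

assert depends_on(recv_op, send_op)  # recv now waits for send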

File tree

6 files changed: +50 -9 lines changed


paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 22 additions & 4 deletions

@@ -763,6 +763,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
 // Create RPC related op handles that connects its in ops and out ops.
 void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                                           ir::Node *node) const {
+  // FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode
+  // put them into transpiler.
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
@@ -771,26 +773,42 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                    "This hack no longer holds, please fix.");
     // the variable name which contains .block means it was splited by
     // split_byref op
-    // so that we can balance the variable blocks to all the pserver
-    // instances.
     if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
         node->inputs[0]->Name().find(".block") == std::string::npos) {
       std::vector<std::string> input_var_names;
       for (ir::Node *n : node->inputs) {
         input_var_names.push_back(n->Name());
       }
-      op_dev_id = GetAppropriateDeviceID(input_var_names);
+      auto send_param_grad = boost::get<std::vector<std::string>>(
+          node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+      PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U);
+      op_dev_id = GetAppropriateDeviceID({send_param_grad[1]});
+      VLOG(10) << "send grad " << input_var_names[0] << " origin "
+               << send_param_grad[1] << " place: " << op_dev_id;
       for (auto &varname : input_var_names) {
         result->Get<ShardedVarDevice>(kShardedVarDevice)
             .emplace(varname, op_dev_id);
       }
+      result->Get<ShardedVarDevice>(kShardedVarDevice)
+          .emplace(send_param_grad[1], op_dev_id);
     }
   } else if (node->Op()->Type() == "recv") {
     std::vector<std::string> output_var_names;
     for (ir::Node *n : node->outputs) {
       output_var_names.push_back(n->Name());
     }
-    op_dev_id = GetAppropriateDeviceID(output_var_names);
+    auto recv_param_grad = boost::get<std::vector<std::string>>(
+        node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+    // FIXME(typhoonzero): assume each recv op output one param
+    // Use the same place as send.
+    if (recv_param_grad.size() == 2U) {
+      op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
+      VLOG(10) << "recv param " << recv_param_grad[0]
+               << " get grad place: " << recv_param_grad[1]
+               << " place: " << op_dev_id;
+    } else {
+      op_dev_id = GetAppropriateDeviceID(output_var_names);
+    }
     for (auto &varname : output_var_names) {
       result->Get<ShardedVarDevice>(kShardedVarDevice)
           .emplace(varname, op_dev_id);
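The effect of this hunk: send now derives its device from the original gradient name carried in the [param, grad] op-role-var attribute and records it in ShardedVarDevice, and recv reuses that record via GetVarDeviceID, so both ends of an RPC pair land on the same device. A toy Python stand-in for the lookup protocol (the hash-based balancing rule is illustrative; Paddle's GetAppropriateDeviceID chooses devices differently):

sharded_var_device = {}  # varname -> device id, like ShardedVarDevice

def place_send(block_name, origin_grad, num_devices=4):
    # Key the choice on the *origin* grad name so every block split
    # from one gradient, and the gradient itself, share a device.
    dev = hash(origin_grad) % num_devices
    sharded_var_device[origin_grad] = dev
    sharded_var_device[block_name] = dev
    return dev

def place_recv(origin_grad, output_names, num_devices=4):
    # "Use the same place as send" when the grad was recorded,
    # else fall back to placing by the outputs (the else branch).
    if origin_grad in sharded_var_device:
        return sharded_var_device[origin_grad]
    return hash(output_names[0]) % num_devices

send_dev = place_send("w@GRAD.block0", "w@GRAD")
assert place_recv("w@GRAD", ["w"]) == send_dev  # recv lands with send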

paddle/fluid/framework/ir/node.cc

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 namespace ir {
-const char Node::kControlDepVarName[] = "__control_var";
+constexpr char Node::kControlDepVarName[];
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/framework/ir/node.h

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ namespace ir {
 class Node {
  public:
   enum class Type { kOperation, kVariable };
-  static const char kControlDepVarName[];
+  static constexpr char kControlDepVarName[] = "__control_var";

   explicit Node(const std::string& name, Type type)
       : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
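A note on the const-to-constexpr swap across node.h and node.cc: the in-class constexpr initializer makes the string's value usable at compile time in every translation unit that includes node.h (such as the new pybind lambda below), while the now-empty out-of-line definition in node.cc still supplies the one symbol that odr-used references require under pre-C++17 rules.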

paddle/fluid/pybind/const_value.cc

Lines changed: 4 additions & 1 deletion

@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/pybind/const_value.h"
-#include <paddle/fluid/framework/op_proto_maker.h>
+#include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"

 namespace paddle {
@@ -24,6 +25,8 @@ void BindConstValue(pybind11::module* m) {
   m->def("kTempVarName", [] { return framework::kTempVarName; });
   m->def("kGradVarSuffix", [] { return framework::kGradVarSuffix; });
   m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });
+  m->def("kControlDepVarName",
+         [] { return framework::ir::Node::kControlDepVarName; });

   auto op_proto_and_checker_maker =
       m->def_submodule("op_proto_and_checker_maker");
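With the binding in place, Python can read the same constant the C++ graph code uses, which is exactly what framework.py does next. A quick check, assuming a built paddle.fluid:

from paddle.fluid import core

# Value comes from framework::ir::Node::kControlDepVarName.
print(core.kControlDepVarName())  # __control_var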

python/paddle/fluid/framework.py

Lines changed: 6 additions & 0 deletions

@@ -49,6 +49,12 @@
 TEMP_VAR_NAME = core.kTempVarName()
 GRAD_VAR_SUFFIX = core.kGradVarSuffix()
 ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
+CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
+
+
+def generate_control_dev_var_name():
+    import random
+    return CONTROL_DEP_VAR_PREFIX + "@" + str(random.random())


 def grad_var_name(var_name):
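The helper just mints unique names under the control-dep prefix; each call gets a fresh random suffix, so every send op receives its own dummy output variable. A quick shape check, assuming this module is importable:

from paddle.fluid import framework

a = framework.generate_control_dev_var_name()
b = framework.generate_control_dev_var_name()
# e.g. "__control_var@0.8444218515250481"
assert a.startswith(framework.CONTROL_DEP_VAR_PREFIX + "@")
assert a != b  # unique per call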

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 16 additions & 2 deletions

@@ -211,8 +211,10 @@ def transpile(self,
         ps_dispatcher = self.config.split_method(self.pserver_endpoints)
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()
         self.param_name_to_grad_name = dict()
+        self.grad_name_to_param_name = dict()
         for param_var, grad_var in self.params_grads:
             self.param_name_to_grad_name[param_var.name] = grad_var.name
+            self.grad_name_to_param_name[grad_var.name] = param_var.name

         # step 1: split and create vars, then put splited vars in dicts for later use.
         self._init_splited_vars()
@@ -254,8 +256,10 @@ def transpile(self,
                 AssertionError("Can not insert the send op by original "
                                "variable name :", splited_grad_varname)

-            dummy_output = program.global_block().create_var()
+            dummy_output = program.global_block().create_var(
+                name=framework.generate_control_dev_var_name())
             grad_name_to_send_dummy_out[grad_varname] = dummy_output
+
             program.global_block()._insert_op(
                 index=index + 1,
                 type="send",
@@ -264,6 +268,8 @@ def transpile(self,
                 attrs={
                     "epmap": eplist,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    OP_ROLE_VAR_ATTR_NAME:
+                    [self.grad_name_to_param_name[grad_varname], grad_varname],
                     "sync_mode": not self.sync_mode,
                 })
         for _, var in enumerate(splited_vars):
@@ -305,6 +311,10 @@ def transpile(self,
                 attrs={
                     "epmap": eps,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    OP_ROLE_VAR_ATTR_NAME: [
+                        param_varname,
+                        self.param_name_to_grad_name[param_varname]
+                    ],
                     "sync_mode": not self.sync_mode
                 })

@@ -934,7 +944,11 @@ def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
             attrs={
                 "sync_mode": True,
                 "epmap": pserver_endpoints,
-                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                OP_ROLE_VAR_ATTR_NAME: [
+                    self.grad_name_to_param_name[table_grad_name],
+                    table_grad_name
+                ]
             })
         break