
Commit 6d2ce74

Merge remote-tracking branch 'ups/develop' into remove/kwargs
2 parents: 808c3ef + f7af695

18 files changed: +267 −186 lines

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 4 additions & 38 deletions
@@ -210,43 +210,6 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
  return recv_vars;
}

-bool MultiDevSSAGraphBuilder::IsDistTrainOp(
-    ir::Node *node, const std::vector<std::string> &send_vars,
-    const std::vector<std::string> &recv_vars) const {
-  if (send_vars.size() == 0 || recv_vars.size() == 0) {
-    return false;
-  }
-
-  /**
-   * Check any of opvars contains `.block` and in sendvars
-   */
-  auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &rpc_vars) -> bool {
-    for (auto &var : opvars) {
-      // a variable name with the suffix `.block` means it's a splited
-      // variable by (DistributeTranspiler)
-      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
-      if (var.find(".block") != std::string::npos &&
-          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
-        return true;
-      }
-    }
-    return false;
-  };
-
-  std::vector<std::string> input_var_names;
-  std::vector<std::string> output_var_names;
-  for (ir::Node *input : node->inputs) {
-    input_var_names.push_back(input->Name());
-  }
-  for (ir::Node *output : node->outputs) {
-    output_var_names.push_back(output->Name());
-  }
-
-  return checker(output_var_names, send_vars) ||
-         checker(input_var_names, recv_vars);
-}
-
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
    const std::vector<std::string> &var_names) const {
  int64_t numel_sum = 0;
@@ -370,7 +333,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
      }
    }
    is_dist_train = true;
-  } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
+  } else if (boost::get<int>(node->Op()->GetAttr(
+                 OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+             static_cast<int>(OpRole::kDist)) {
    int op_dev_id = CreateDistTrainOp(&result, node);
    if (node->Op()->Type() == "concat") {
      auto origin_param_name = node->Op()->OutputArgumentNames()[0];
@@ -736,6 +701,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
          .emplace(varname, op_dev_id);
    }
  } else {
+    LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
    PADDLE_THROW(
        "the distribute training related op should be in [split_byref, "
        "concat].");

paddle/fluid/framework/details/multi_devices_graph_pass.h

Lines changed: 0 additions & 6 deletions
@@ -51,12 +51,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
  int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
  int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;

-  /**
-   * Is this operator as the end-point operator before/after send operator.
-   */
-  bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
-                     const std::vector<std::string> &recv_vars) const;
-
  std::vector<std::string> FindDistTrainSendVars(
      const std::vector<ir::Node *> &nodes) const;

paddle/fluid/framework/op_proto_maker.cc

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
          {static_cast<int>(OpRole::kForward),
           static_cast<int>(OpRole::kBackward),
           static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
+           static_cast<int>(OpRole::kDist), static_cast<int>(OpRole::kLRSched),
           static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
           static_cast<int>(OpRole::kLoss) |
               static_cast<int>(OpRole::kBackward),

paddle/fluid/framework/op_proto_maker.h

Lines changed: 6 additions & 0 deletions
@@ -26,7 +26,13 @@ enum class OpRole {
  kForward = 0x0000,
  kBackward = 0x0001,
  kOptimize = 0x0002,
+  // RPC role is for send/recv releated op
  kRPC = 0x0003,
+  // Dist role is for split_byref/split_selected_rows/concat
+  // used for distributed training.
+  kDist = 0x0004,
+  // Tag all learning rate scheduler operators.
+  kLRSched = 0x0005,

  kLoss = 0x0100,
  // The default value of op's role. This should be only used for unittests and
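
Note: the numbering keeps the existing convention that kLoss (0x0100) sits above the low-valued roles so it can be OR'ed with kForward/kBackward, which is why op_proto_maker.cc above whitelists the combined values alongside the new kDist and kLRSched entries. A small Python illustration of that arithmetic (values copied from the enum):

K_FORWARD, K_BACKWARD, K_OPTIMIZE = 0x0000, 0x0001, 0x0002
K_RPC, K_DIST, K_LRSCHED = 0x0003, 0x0004, 0x0005
K_LOSS = 0x0100

# kLoss occupies a high bit, so OR-ing never collides with the low role values:
assert K_LOSS | K_FORWARD == 0x0100
assert K_LOSS | K_BACKWARD == 0x0101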

paddle/fluid/operators/distributed/variable_response.cc

Lines changed: 6 additions & 2 deletions
@@ -92,9 +92,14 @@ bool VariableResponse::CopyLodTensorData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, const framework::DDim& dims,
    int length) {
+  auto server_var = GetVar();
+  if (!server_var) {
+    LOG(ERROR) << "recved var should not on current server: "
+               << meta_.varname();
+    return false;
+  }
  auto* tensor = GetVar()->GetMutable<framework::LoDTensor>();
  tensor->Resize(dims);
-
  framework::LoD lod;
  for (int i = 0; i < meta_.lod_level(); ++i) {
    framework::Vector<size_t> v;
@@ -107,7 +112,6 @@ bool VariableResponse::CopyLodTensorData(

  void* tensor_data =
      tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
-
  if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
    return false;
  }

paddle/fluid/operators/math/selected_rows_functor.cu

Lines changed: 2 additions & 2 deletions
@@ -107,7 +107,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
    PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);

    auto& in1_value = input1.value();
-    framework::Vector<int64_t> in1_rows(input1.rows());
+    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
@@ -206,7 +206,7 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);

    auto& in1_value = input1.value();
-    framework::Vector<int64_t> in1_rows(input1.rows());
+    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);

paddle/fluid/operators/math/selected_rows_functor_test.cu

Lines changed: 6 additions & 2 deletions
@@ -20,7 +20,9 @@ limitations under the License. */
TEST(selected_rows_functor, gpu_add) {
  paddle::platform::CUDAPlace gpu_place(0);
  paddle::platform::CPUPlace cpu_place;
-  paddle::platform::CUDADeviceContext ctx(gpu_place);
+  paddle::platform::CUDADeviceContext& ctx =
+      *reinterpret_cast<paddle::platform::CUDADeviceContext*>(
+          paddle::platform::DeviceContextPool::Instance().Get(gpu_place));
  paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
                                       float>
      functor;
@@ -132,7 +134,9 @@ TEST(selected_rows_functor, gpu_add) {
TEST(selected_rows_functor, gpu_add_to) {
  paddle::platform::CUDAPlace gpu_place(0);
  paddle::platform::CPUPlace cpu_place;
-  paddle::platform::CUDADeviceContext ctx(gpu_place);
+  paddle::platform::CUDADeviceContext& ctx =
+      *reinterpret_cast<paddle::platform::CUDADeviceContext*>(
+          paddle::platform::DeviceContextPool::Instance().Get(gpu_place));
  paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
                                       float>
      functor;

paddle/fluid/pybind/const_value.cc

Lines changed: 3 additions & 1 deletion
@@ -36,7 +36,9 @@ void BindConstValue(pybind11::module* m) {
      .value("Backward", framework::OpRole::kBackward)
      .value("Optimize", framework::OpRole::kOptimize)
      .value("Loss", framework::OpRole::kLoss)
-      .value("RPC", framework::OpRole::kRPC);
+      .value("RPC", framework::OpRole::kRPC)
+      .value("Dist", framework::OpRole::kDist)
+      .value("LRSched", framework::OpRole::kLRSched);

  op_proto_and_checker_maker.def(
      "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);

python/paddle/fluid/framework.py

Lines changed: 24 additions & 0 deletions
@@ -1509,6 +1509,30 @@ def _optimized_guard(self, param_and_grads):
        self._op_role_var = []
        self._current_role = OpRole.Forward

+    @contextlib.contextmanager
+    def _lr_schedule_guard(self):
+        """
+        A with guard to set :code:`LRSched` :code:`OpRole` and
+        :code:`OpRoleVar` automatically. The :code:`OpRoleVar` is
+        set to the target learning rate.
+
+        Notes: This is a very low level API. Users should not use it directly.
+
+        Examples:
+
+            >>> p, g = backward(...)
+            >>> with program.lr_schedule_guard():
+            >>>     lr = lr * decay
+        """
+        OpRole = core.op_proto_and_checker_maker.OpRole
+        self._current_role = OpRole.LRSched
+        # TODO(typhoonzero): how to set target learning rate var
+        self._op_role_var = []
+        yield
+        self._op_role_var = []
+        self._current_role = OpRole.Forward
+
    def __str__(self):
        """
        Get the protobuf debug string of this Program.
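
Note: a minimal usage sketch of the new guard; `main_program`, `lr_var`, and `decay_rate` are placeholder names, and the method is the underscored `_lr_schedule_guard` shown above:

# Ops created inside the guard are tagged with OpRole.LRSched, e.g. a decay step:
with main_program._lr_schedule_guard():
    decayed_lr = lr_var * decay_rate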
