
Commit 29c63d1

[Feature] dist op role and lr op role, to support memory optimize with dist training (#13220)
* wip
* clean up
* should fix running with memopt
* add ut
* mark lr schedule op role
* hide lr_schedule_guard
* use op_role_var instead of ufind
* unify dist test name
* wip for py3 support
* fix var deref
* fix python3 mem_opt order
* remove comments
1 parent 2d97903 commit 29c63d1

15 files changed: +257 -181 lines
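
The core of the change is that every operator carries an integer op_role attribute; this commit adds two new values, Dist and LRSched, and tags the relevant operators with them, so later passes (the multi-devices graph builder below, and memory optimization for distributed training) can identify those ops by attribute instead of by variable-name heuristics. A minimal sketch of reading that attribute from Python once the commit is applied; the loop itself is illustrative and not part of the diff:

import paddle.fluid as fluid
from paddle.fluid import core

op_maker = core.op_proto_and_checker_maker
ROLE_ATTR = op_maker.kOpRoleAttrName()  # name of the role attribute every op carries

for op in fluid.default_main_program().global_block().ops:
    role = op.attr(ROLE_ATTR)
    if role == int(op_maker.OpRole.Dist):
        print("dist op:", op.type)       # e.g. split_byref / concat inserted by the transpiler
    elif role == int(op_maker.OpRole.LRSched):
        print("lr schedule op:", op.type)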

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 4 additions & 38 deletions
@@ -210,43 +210,6 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
   return recv_vars;
 }

-bool MultiDevSSAGraphBuilder::IsDistTrainOp(
-    ir::Node *node, const std::vector<std::string> &send_vars,
-    const std::vector<std::string> &recv_vars) const {
-  if (send_vars.size() == 0 || recv_vars.size() == 0) {
-    return false;
-  }
-
-  /**
-   * Check any of opvars contains `.block` and in sendvars
-   */
-  auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &rpc_vars) -> bool {
-    for (auto &var : opvars) {
-      // a variable name with the suffix `.block` means it's a splited
-      // variable by (DistributeTranspiler)
-      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
-      if (var.find(".block") != std::string::npos &&
-          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
-        return true;
-      }
-    }
-    return false;
-  };
-
-  std::vector<std::string> input_var_names;
-  std::vector<std::string> output_var_names;
-  for (ir::Node *input : node->inputs) {
-    input_var_names.push_back(input->Name());
-  }
-  for (ir::Node *output : node->outputs) {
-    output_var_names.push_back(output->Name());
-  }
-
-  return checker(output_var_names, send_vars) ||
-         checker(input_var_names, recv_vars);
-}
-
 size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
     const std::vector<std::string> &var_names) const {
   int64_t numel_sum = 0;
@@ -370,7 +333,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
         }
       }
       is_dist_train = true;
-    } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
+    } else if (boost::get<int>(node->Op()->GetAttr(
+                   OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+               static_cast<int>(OpRole::kDist)) {
       int op_dev_id = CreateDistTrainOp(&result, node);
       if (node->Op()->Type() == "concat") {
         auto origin_param_name = node->Op()->OutputArgumentNames()[0];
@@ -736,6 +701,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
           .emplace(varname, op_dev_id);
     }
   } else {
+    LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
     PADDLE_THROW(
         "the distribute training related op should be in [split_byref, "
        "concat].");

paddle/fluid/framework/details/multi_devices_graph_pass.h

Lines changed: 0 additions & 6 deletions
@@ -51,12 +51,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
   int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;

-  /**
-   * Is this operator as the end-point operator before/after send operator.
-   */
-  bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
-                     const std::vector<std::string> &recv_vars) const;
-
   std::vector<std::string> FindDistTrainSendVars(
       const std::vector<ir::Node *> &nodes) const;


paddle/fluid/framework/op_proto_maker.cc

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
       {static_cast<int>(OpRole::kForward),
        static_cast<int>(OpRole::kBackward),
        static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
+       static_cast<int>(OpRole::kDist), static_cast<int>(OpRole::kLRSched),
        static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
        static_cast<int>(OpRole::kLoss) |
            static_cast<int>(OpRole::kBackward),

paddle/fluid/framework/op_proto_maker.h

Lines changed: 6 additions & 0 deletions
@@ -26,7 +26,13 @@ enum class OpRole {
   kForward = 0x0000,
   kBackward = 0x0001,
   kOptimize = 0x0002,
+  // RPC role is for send/recv releated op
   kRPC = 0x0003,
+  // Dist role is for split_byref/split_selected_rows/concat
+  // used for distributed training.
+  kDist = 0x0004,
+  // Tag all learning rate scheduler operators.
+  kLRSched = 0x0005,

   kLoss = 0x0100,
   // The default value of op's role. This should be only used for unittests and
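
For reference, the role values are small integers and kLoss acts as a flag that is OR-ed with kForward or kBackward; those two combinations are exactly the extra entries allowed by the op_proto_maker.cc hunk above, while kDist and kLRSched are standalone values. A quick check in plain Python, with the values copied from the enum:

FORWARD, BACKWARD, OPTIMIZE, RPC = 0x0000, 0x0001, 0x0002, 0x0003
DIST, LRSCHED, LOSS = 0x0004, 0x0005, 0x0100

assert LOSS | FORWARD == 0x0100   # a loss op in the forward pass
assert LOSS | BACKWARD == 0x0101  # the loss op's backward op
assert DIST == 0x0004 and LRSCHED == 0x0005  # new roles, no flag combination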

paddle/fluid/operators/distributed/variable_response.cc

Lines changed: 6 additions & 2 deletions
@@ -92,9 +92,14 @@ bool VariableResponse::CopyLodTensorData(
     ::google::protobuf::io::CodedInputStream* input,
     const platform::DeviceContext& ctx, const framework::DDim& dims,
     int length) {
+  auto server_var = GetVar();
+  if (!server_var) {
+    LOG(ERROR) << "recved var should not on current server: "
+               << meta_.varname();
+    return false;
+  }
   auto* tensor = GetVar()->GetMutable<framework::LoDTensor>();
   tensor->Resize(dims);
-
   framework::LoD lod;
   for (int i = 0; i < meta_.lod_level(); ++i) {
     framework::Vector<size_t> v;
@@ -107,7 +112,6 @@ bool VariableResponse::CopyLodTensorData(

   void* tensor_data =
       tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
-
   if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
     return false;
   }

paddle/fluid/pybind/const_value.cc

Lines changed: 3 additions & 1 deletion
@@ -36,7 +36,9 @@ void BindConstValue(pybind11::module* m) {
       .value("Backward", framework::OpRole::kBackward)
       .value("Optimize", framework::OpRole::kOptimize)
       .value("Loss", framework::OpRole::kLoss)
-      .value("RPC", framework::OpRole::kRPC);
+      .value("RPC", framework::OpRole::kRPC)
+      .value("Dist", framework::OpRole::kDist)
+      .value("LRSched", framework::OpRole::kLRSched);

   op_proto_and_checker_maker.def(
       "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);

python/paddle/fluid/framework.py

Lines changed: 24 additions & 0 deletions
@@ -1509,6 +1509,30 @@ def _optimized_guard(self, param_and_grads):
         self._op_role_var = []
         self._current_role = OpRole.Forward

+    @contextlib.contextmanager
+    def _lr_schedule_guard(self):
+        """
+        A with guard to set :code:`LRSched` :code:`OpRole` and
+        :code:`OpRoleVar` automatically. The :code:`OpRoleVar` is
+        set to the target learning rate.
+
+        Notes: This is a very low level API. Users should not use it directly.
+
+
+        Examples:
+
+            >>> p, g = backward(...)
+            >>> with program.lr_schedule_guard():
+            >>>     lr = lr * decay
+        """
+        OpRole = core.op_proto_and_checker_maker.OpRole
+        self._current_role = OpRole.LRSched
+        # TODO(typhoonzero): how to set target learning rate var
+        self._op_role_var = []
+        yield
+        self._op_role_var = []
+        self._current_role = OpRole.Forward
+
     def __str__(self):
         """
         Get the protobuf debug string of this Program.
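
The guard works the same way as the existing _optimized_guard: while it is active, Program._current_role is LRSched, so every op appended inside the with block is created with that role. A sketch of that effect, assuming the role is attached at op-creation time as for the other guards (the fill_constant layer is just a stand-in for a decay computation):

import paddle.fluid as fluid
from paddle.fluid import core

op_maker = core.op_proto_and_checker_maker
prog = fluid.default_main_program()

with prog._lr_schedule_guard():
    lr = fluid.layers.fill_constant(shape=[1], dtype='float32', value=0.01)

last_op = prog.global_block().ops[-1]
assert last_op.attr(op_maker.kOpRoleAttrName()) == int(op_maker.OpRole.LRSched)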

python/paddle/fluid/layers/learning_rate_scheduler.py

Lines changed: 73 additions & 64 deletions
@@ -27,7 +27,7 @@
 from . import ops
 from . import tensor
 from ..initializer import init_on_cpu
-from ..framework import default_main_program, Parameter
+from ..framework import default_main_program, Parameter, unique_name

 __all__ = [
     'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
@@ -63,11 +63,12 @@ def noam_decay(d_model, warmup_steps):
     Returns:
         The decayed learning rate.
     """
-    global_step = _decay_step_counter(1)
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter(1)

-    a = global_step**-0.5
-    b = (warmup_steps**-1.5) * global_step
-    lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)
+        a = global_step**-0.5
+        b = (warmup_steps**-1.5) * global_step
+        lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)

     return lr_value

@@ -108,14 +109,15 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
           sgd_optimizer.minimize(avg_cost)

     """
-    global_step = _decay_step_counter()
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()

-    div_res = global_step / decay_steps
-    if staircase:
-        div_res = ops.floor(div_res)
-    decayed_lr = learning_rate * (decay_rate**div_res)
+        div_res = global_step / decay_steps
+        if staircase:
+            div_res = ops.floor(div_res)
+        decayed_lr = learning_rate * (decay_rate**div_res)

-    return decayed_lr
+        return decayed_lr


 def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
@@ -136,14 +138,15 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     Returns:
         The decayed learning rate
     """
-    global_step = _decay_step_counter()
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()

-    div_res = global_step / decay_steps
-    if staircase:
-        div_res = ops.floor(div_res)
-    decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
+        div_res = global_step / decay_steps
+        if staircase:
+            div_res = ops.floor(div_res)
+        decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)

-    return decayed_lr
+        return decayed_lr


 def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
@@ -181,15 +184,16 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
                   staircase=True))
           sgd_optimizer.minimize(avg_cost)
     """
-    global_step = _decay_step_counter()
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()

-    div_res = global_step / decay_steps
-    if staircase:
-        div_res = ops.floor(div_res)
+        div_res = global_step / decay_steps
+        if staircase:
+            div_res = ops.floor(div_res)

-    decayed_lr = learning_rate / (1 + decay_rate * div_res)
+        decayed_lr = learning_rate / (1 + decay_rate * div_res)

-    return decayed_lr
+        return decayed_lr


 def polynomial_decay(learning_rate,
@@ -220,25 +224,28 @@ def polynomial_decay(learning_rate,
     Returns:
         Variable: The decayed learning rate
     """
-    global_step = _decay_step_counter()
-
-    if cycle:
-        div_res = ops.ceil(global_step / decay_steps)
-        zero_var = tensor.fill_constant(shape=[1], dtype='float32', value=0.0)
-        one_var = tensor.fill_constant(shape=[1], dtype='float32', value=1.0)
-
-        with control_flow.Switch() as switch:
-            with switch.case(global_step == zero_var):
-                tensor.assign(input=one_var, output=div_res)
-        decay_steps = decay_steps * div_res
-    else:
-        decay_steps_var = tensor.fill_constant(
-            shape=[1], dtype='float32', value=float(decay_steps))
-        global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()
+
+        if cycle:
+            div_res = ops.ceil(global_step / decay_steps)
+            zero_var = tensor.fill_constant(
+                shape=[1], dtype='float32', value=0.0)
+            one_var = tensor.fill_constant(
+                shape=[1], dtype='float32', value=1.0)
+
+            with control_flow.Switch() as switch:
+                with switch.case(global_step == zero_var):
+                    tensor.assign(input=one_var, output=div_res)
+            decay_steps = decay_steps * div_res
+        else:
+            decay_steps_var = tensor.fill_constant(
+                shape=[1], dtype='float32', value=float(decay_steps))
+            global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)

-    decayed_lr = (learning_rate - end_learning_rate) * \
-        ((1 - global_step / decay_steps) ** power) + end_learning_rate
-    return decayed_lr
+        decayed_lr = (learning_rate - end_learning_rate) * \
+            ((1 - global_step / decay_steps) ** power) + end_learning_rate
+        return decayed_lr


 def piecewise_decay(boundaries, values):
@@ -266,34 +273,36 @@ def piecewise_decay(boundaries, values):


     """
+    with default_main_program()._lr_schedule_guard():
+        if len(values) - len(boundaries) != 1:
+            raise ValueError("len(values) - len(boundaries) should be 1")

-    if len(values) - len(boundaries) != 1:
-        raise ValueError("len(values) - len(boundaries) should be 1")
-
-    global_step = _decay_step_counter()
+        global_step = _decay_step_counter()

-    lr = tensor.create_global_var(
-        shape=[1],
-        value=0.0,
-        dtype='float32',
-        persistable=True,
-        name="learning_rate")
+        lr = tensor.create_global_var(
+            shape=[1],
+            value=0.0,
+            dtype='float32',
+            persistable=True,
+            name="learning_rate")

-    with control_flow.Switch() as switch:
-        for i in range(len(boundaries)):
-            boundary_val = tensor.fill_constant(
+        with control_flow.Switch() as switch:
+            for i in range(len(boundaries)):
+                boundary_val = tensor.fill_constant(
+                    shape=[1],
+                    dtype='float32',
+                    value=float(boundaries[i]),
+                    force_cpu=True)
+                value_var = tensor.fill_constant(
+                    shape=[1], dtype='float32', value=float(values[i]))
+                with switch.case(global_step < boundary_val):
+                    tensor.assign(value_var, lr)
+            last_value_var = tensor.fill_constant(
                 shape=[1],
                 dtype='float32',
-                value=float(boundaries[i]),
-                force_cpu=True)
-            value_var = tensor.fill_constant(
-                shape=[1], dtype='float32', value=float(values[i]))
-            with switch.case(global_step < boundary_val):
-                tensor.assign(value_var, lr)
-            last_value_var = tensor.fill_constant(
-                shape=[1], dtype='float32', value=float(values[len(values) - 1]))
-            with switch.default():
-                tensor.assign(last_value_var, lr)
+                value=float(values[len(values) - 1]))
+            with switch.default():
+                tensor.assign(last_value_var, lr)

     return lr

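
For callers nothing changes: the decay helpers are used exactly as before, and the guard they now open internally is what marks the generated ops as LRSched. A typical caller-side setup (illustrative network, mirroring the docstring examples above):

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
avg_cost = fluid.layers.mean(
    fluid.layers.square_error_cost(input=pred, label=y))

sgd_optimizer = fluid.optimizer.SGD(
    learning_rate=fluid.layers.exponential_decay(
        learning_rate=0.1, decay_steps=10000, decay_rate=0.5, staircase=True))
sgd_optimizer.minimize(avg_cost)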

python/paddle/fluid/tests/unittests/test_dist_mnist.py

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@ def _setup_config(self):
         self._sync_mode = True
         self._use_reduce = False

-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_mnist.py", delta=1e-7)


@@ -31,7 +31,7 @@ def _setup_config(self):
         self._sync_mode = True
         self._mem_opt = True

-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_mnist.py", delta=1e-7)


@@ -40,7 +40,7 @@ def _setup_config(self):
         self._sync_mode = False
         self._use_reduce = False

-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_mnist.py", delta=200)

