Commit d3da0ef

Fix dist train with rmsprop (#12649)
* fix dist train with rmsprop
* add rmsprop transpiler test
* update by comment
1 parent 989cae2 commit d3da0ef

File tree

2 files changed: +58 -7 lines changed


python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 30 additions & 0 deletions
@@ -536,5 +536,35 @@ def transpiler_test_impl(self):
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
 
+class TestRMSPropOptimizer(TranspilerTest):
+    def net_conf(self):
+        x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
+        y_predict = fluid.layers.fc(input=x,
+                                    size=1000,
+                                    act=None,
+                                    param_attr=fluid.ParamAttr(name='fc_w'),
+                                    bias_attr=fluid.ParamAttr(name='fc_b'))
+        y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+        cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+        avg_cost = fluid.layers.mean(cost)
+        optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
+        optimizer.minimize(avg_cost)
+        return
+
+    def transpiler_test_impl(self):
+        pserver, startup = self.get_pserver(self.pserver1_ep)
+        pserver2, startup2 = self.get_pserver(self.pserver2_ep)
+
+        self.assertEqual(len(pserver.blocks), 3)
+        # block1~2: optimize pass
+        self.assertEqual([op.type for op in pserver.blocks[1].ops],
+                         ["sum", "scale", "rmsprop"])
+        # the variable #fc_w will be split into two blocks
+        fc_w_var = startup.global_block().var("fc_w.block1")
+        self.assertEqual(fc_w_var.shape, (500, 1000))
+        moment_var = startup.global_block().var("momentum_1")
+        self.assertEqual(moment_var.shape, (500, 1000))
+
+
 if __name__ == "__main__":
     unittest.main()
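
For context, the distributed setup this test simulates looks roughly like the sketch below: the same single-fc regression net trained with RMSProp is handed to fluid.DistributeTranspiler, which produces the parameter-server and trainer programs that transpiler_test_impl inspects via get_pserver. This is a minimal sketch against the Fluid API of this release; the endpoint strings and trainer count are illustrative placeholders, not taken from the commit.

import paddle.fluid as fluid

# Same net as TestRMSPropOptimizer.net_conf: a 1000-wide fc layer with
# square error loss, optimized by RMSProp.
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
y_predict = fluid.layers.fc(input=x,
                            size=1000,
                            act=None,
                            param_attr=fluid.ParamAttr(name='fc_w'),
                            bias_attr=fluid.ParamAttr(name='fc_b'))
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
avg_cost = fluid.layers.mean(
    fluid.layers.square_error_cost(input=y_predict, label=y))
fluid.optimizer.RMSProp(learning_rate=0.1).minimize(avg_cost)

# Transpile for two parameter servers and one trainer (endpoints are
# placeholders; any reachable host:port pairs would do).
pserver_eps = "127.0.0.1:6174,127.0.0.1:6175"
current_ep = "127.0.0.1:6174"
t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0, pservers=pserver_eps, trainers=1)

pserver_prog = t.get_pserver_program(current_ep)
pserver_startup = t.get_startup_program(current_ep, pserver_prog)
trainer_prog = t.get_trainer_program()

# With this fix the pserver optimize block contains an rmsprop op, and fc_w
# plus its Moment and MeanSquare accumulators are each created as
# (500, 1000) blocks on every pserver.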

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 28 additions & 7 deletions
@@ -1182,18 +1182,39 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
         program = optimize_block.program
         pserver_block = program.global_block()
         new_inputs = dict()
+
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
+        def _get_param_block(opt_op):
+            # param is already created on global program
+            param_block = None
+            for p in self.param_grad_ep_mapping[endpoint]["params"]:
+                if same_or_split_var(p.name, opt_op.input("Param")[0]):
+                    param_block = p
+                    break
+            return param_block
+
         for key in opt_op.input_names:
             if key == "Grad":
                 new_inputs[key] = merged_var
+            # For RMSProp optimizer
+            elif key == "Moment" or key == "MeanSquare":
+                param_block = _get_param_block(opt_op)
+                if not param_block:
+                    return
+                moment_var = origin_program.global_block().vars[opt_op.input(
+                    key)[0]]
+                tmpvar = pserver_block.create_var(
+                    name=moment_var.name,
+                    persistable=moment_var.persistable,
+                    dtype=moment_var.dtype,
+                    # change to use same shape as param
+                    # TODO(typhoonzero): didn't append .block in the var name,
+                    # may affect checkpoint saving? Need to verify.
+                    shape=param_block.shape)
+                new_inputs[key] = tmpvar
             elif key == "Param":
-                # param is already created on global program
-                param_block = None
-                for p in self.param_grad_ep_mapping[endpoint]["params"]:
-                    if same_or_split_var(p.name, opt_op.input(key)[0]):
-                        param_block = p
-                        break
+                param_block = _get_param_block(opt_op)
                 if not param_block:
                     return
                 tmpvar = pserver_block.create_var(
@@ -1219,7 +1240,7 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
 
         for key in opt_op.input_names:
             new_shape = None
-            if key in ["Param", "Grad", "LearningRate"]:
+            if key in ["Param", "Grad", "LearningRate", "Moment", "MeanSquare"]:
                 continue
             var = self.origin_program.global_block().vars[opt_op.input(key)[0]]
             # update accumulator variable shape
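
The heart of the fix is the shape handed to pserver_block.create_var for Moment and MeanSquare: RMSProp's accumulators are updated element-wise against the parameter, so a pserver that holds only one row-split block of fc_w must create them with that block's shape rather than the full-parameter shape stored in origin_program. A minimal NumPy sketch of this constraint, assuming the textbook RMSProp-with-momentum update and an illustrative two-pserver split (hyperparameter values are placeholders):

import numpy as np

# Illustration only: fc_w is (1000, 1000) in the trainer program, but each
# pserver owns a row-split block of it.
full_param_shape = (1000, 1000)
num_pservers = 2
param_block = np.zeros((full_param_shape[0] // num_pservers,
                        full_param_shape[1]))  # (500, 1000) on each pserver

# Accumulators created with the param block's shape, as the patch does.
mean_square = np.zeros_like(param_block)
moment = np.zeros_like(param_block)

grad = np.random.rand(*param_block.shape)      # merged grad for this block
rho, epsilon, lr, momentum = 0.95, 1e-6, 0.1, 0.0

# Element-wise RMSProp step; every operand must share the (500, 1000) shape,
# which is why a full-shape accumulator would break on a split parameter.
mean_square = rho * mean_square + (1 - rho) * grad * grad
moment = momentum * moment + lr * grad / np.sqrt(mean_square + epsilon)
param_block -= moment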
