Commit 6e03f79

Add centered mode rmsprop (#13161)
* rmsprop optimizer support v1 mode
* typo
* optimize code
* refine code
* optimize unit test
* update test_rmsprop_op.py
* update formula of rmsprop
* optimize document
* update API.spec for RMSPropOptimizer
* add default value to check_output_with_place equal_nan
1 parent 9df2d8b commit 6e03f79

File tree

6 files changed: +233 -97 lines

paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion

@@ -376,7 +376,7 @@ paddle.fluid.optimizer.DecayedAdagradOptimizer.__init__ ArgSpec(args=['self', 'l
 paddle.fluid.optimizer.DecayedAdagradOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.optimizer.FtrlOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'l1', 'l2', 'lr_power'], varargs=None, keywords='kwargs', defaults=(0.0, 0.0, -0.5))
 paddle.fluid.optimizer.FtrlOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.RMSPropOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum'], varargs=None, keywords='kwargs', defaults=(0.95, 1e-06, 0.0))
+paddle.fluid.optimizer.RMSPropOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum', 'centered'], varargs=None, keywords='kwargs', defaults=(0.95, 1e-06, 0.0, False))
 paddle.fluid.optimizer.RMSPropOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.optimizer.AdadeltaOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon', 'rho'], varargs=None, keywords='kwargs', defaults=(1e-06, 0.95))
 paddle.fluid.optimizer.AdadeltaOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
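Since the new centered argument is appended after the existing ones with a default of False, existing call sites keep working unchanged. The updated ArgSpec corresponds to a constructor signature like the following sketch (body elided; illustrative, not the actual class definition):

def __init__(self, learning_rate, rho=0.95, epsilon=1e-06, momentum=0.0,
             centered=False, **kwargs):
    # Sketch of the signature implied by the updated ArgSpec line above.
    ...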

paddle/fluid/operators/rmsprop_op.cc

Lines changed: 24 additions & 1 deletion

@@ -36,9 +36,13 @@ class RmspropOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                    "Output(param_out) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
-                   "Output(Momentum_out) of RmspropOp should not be null.");
+                   "Output(MomentOut) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"),
                    "Output(MeanSquareOut) of RmspropOp should not be null.");
+    if (ctx->Attrs().Get<bool>("centered")) {
+      PADDLE_ENFORCE(ctx->HasOutput("MeanGradOut"),
+                     "Output(MeanGradOut) of RmspropOp should not be null.");
+    }
 
     auto param_dim = ctx->GetInputDim("Param");
     PADDLE_ENFORCE_EQ(
@@ -58,6 +62,9 @@ class RmspropOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("MomentOut", param_dim);
     ctx->SetOutputDim("MeanSquareOut", param_dim);
+    if (ctx->Attrs().Get<bool>("centered")) {
+      ctx->SetOutputDim("MeanGradOut", param_dim);
+    }
   }
 };
 
@@ -70,6 +77,10 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("MeanSquare",
              "(Tensor, default Tensor<float>)"
              " The mean square value that gets updated.");
+    AddInput("MeanGrad",
+             "(Tensor, default Tensor<float>)"
+             " The moving average of gradient")
+        .AsDispensable();
     AddInput("LearningRate",
              "(Tensor, default Tensor<float>) "
              "The learning rate should be a tensor of size 1.");
@@ -82,6 +93,8 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("ParamOut", "(Tensor) Output updated parameter value.");
     AddOutput("MomentOut", "(Tensor) Output updated moment.");
     AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value.");
+    AddOutput("MeanGradOut",
+              "(Tensor) Output moving average of gradient updated value.");
 
     AddAttr<float>("epsilon",
                    "(float, default 1e-10) Constant "
@@ -93,6 +106,8 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(0.9f);
     AddAttr<float>("momentum", "(float, default 0.0) Constant value.")
         .SetDefault(0.0f);
+    AddAttr<bool>("centered", "(bool, default false) use centered rmsprop.")
+        .SetDefault(false);
     AddComment(R"DOC(
 Rmsprop Optimizer.
 
@@ -103,6 +118,14 @@ MomentOut = momentum * Moment +
 ParamOut = Param - MomentOut
 $$
 
+if centered is true:
+
+mean_grad = decay * mean_grad{t-1} + (1-decay) * gradient
+mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t /
+    sqrt(mean_square - mean_grad**2 + epsilon)
+param -= mom
+
 The original slides that proposed Rmsprop: Slide 29 of
 http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
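As a sanity check on the formulas above, here is a minimal NumPy sketch of one update step. This is our own illustration, not code from the commit; the name rmsprop_step and the default hyperparameters are hypothetical:

import numpy as np

def rmsprop_step(param, grad, mom, mean_square, mean_grad,
                 learning_rate=0.01, decay=0.9, momentum=0.0,
                 epsilon=1e-10, centered=False):
    """One RMSProp step mirroring the documented update rules."""
    # Second-moment moving average, used by both variants.
    mean_square = decay * mean_square + (1 - decay) * grad ** 2
    if centered:
        # Centered variant: also track the first moment and normalize by
        # an estimate of the gradient variance, E[g^2] - E[g]^2.
        mean_grad = decay * mean_grad + (1 - decay) * grad
        denom = np.sqrt(mean_square - mean_grad ** 2 + epsilon)
    else:
        denom = np.sqrt(mean_square + epsilon)
    mom = momentum * mom + learning_rate * grad / denom
    param = param - mom
    return param, mom, mean_square, mean_grad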

paddle/fluid/operators/rmsprop_op.h

Lines changed: 17 additions & 4 deletions

@@ -41,6 +41,7 @@ class RmspropOpKernel : public framework::OpKernel<T> {
     float epsilon = ctx.Attr<float>("epsilon");
     float rho = ctx.Attr<float>("decay");
     float momentum = ctx.Attr<float>("momentum");
+    bool centered = ctx.Attr<bool>("centered");
 
     auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
     auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
@@ -53,12 +54,24 @@ class RmspropOpKernel : public framework::OpKernel<T> {
     auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
 
-    Eigen::DSizes<int, 1> grad_dsize(grad->numel());
+    Eigen::DSizes<int, 1> grad_dsize(static_cast<int>(grad->numel()));
 
     ms_out.device(place) = rho * ms + (1 - rho) * g * g;
-    mom_out.device(place) =
-        momentum * mom +
-        lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
+    if (centered) {
+      auto mg = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanGrad"));
+      auto* mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
+      mean_grad_out->mutable_data<T>(ctx.GetPlace());
+      auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
+
+      mg_out.device(place) = rho * mg + (1 - rho) * g;
+      mom_out.device(place) = momentum * mom +
+                              lr.broadcast(grad_dsize) * g /
+                                  (ms_out - mg_out.square() + epsilon).sqrt();
+    } else {
+      mom_out.device(place) =
+          momentum * mom +
+          lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
+    }
     p_out.device(place) = p - mom_out;
   }
 };

python/paddle/fluid/optimizer.py

Lines changed: 29 additions & 3 deletions

@@ -897,7 +897,20 @@ class RMSPropOptimizer(Optimizer):
 
         r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2
 
-        v(w, t) & = \\beta v(w, t-1) + \\frac{\\eta} {\\sqrt{v(w,t) +
+        v(w, t) & = \\beta v(w, t-1) + \\frac{\\eta} {\\sqrt{r(w,t) +
+            \\epsilon}} \\nabla Q_{i}(w)
+
+        w & = w - v(w, t)
+
+    if centered is True:
+
+    .. math::
+
+        r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2
+
+        g(w, t) & = \\rho g(w, t-1) + (1 - \\rho)\\nabla Q_{i}(w)
+
+        v(w, t) & = \\beta v(w, t-1) + \\frac{\\eta} {\\sqrt{r(w,t) - (g(w, t))^2 +
             \\epsilon}} \\nabla Q_{i}(w)
 
         w & = w - v(w, t)
@@ -915,6 +928,10 @@ class RMSPropOptimizer(Optimizer):
             avoid division by zero, set 1e-6 by default.
         momentum(float): :math:`\\beta` in equation is the momentum term,
             set 0.0 by default.
+        centered(bool): If True, gradients are normalized by the estimated variance of
+            the gradient; if False, by the uncentered second moment. Setting this to
+            True may help with training, but is slightly more expensive in terms of
+            computation and memory. Defaults to False.
 
     Raises:
         ValueError: If learning_rate, rho, epsilon, momentum are None.
@@ -928,12 +945,14 @@ class RMSPropOptimizer(Optimizer):
 
     _momentum_acc_str = "momentum"
     _mean_square_acc_str = "mean_square"
+    _mean_grad_acc_str = "mean_grad"
 
     def __init__(self,
                  learning_rate,
                  rho=0.95,
                  epsilon=1.0e-6,
                  momentum=0.0,
+                 centered=False,
                  **kwargs):
         super(RMSPropOptimizer, self).__init__(
             learning_rate=learning_rate, **kwargs)
@@ -950,6 +969,7 @@ def __init__(self,
         self._rho = rho
         self._epsilon = epsilon
         self._momentum = momentum
+        self._centered = centered
 
     def _create_accumulators(self, block, parameters):
         if not isinstance(block, framework.Block):
@@ -958,6 +978,7 @@ def _create_accumulators(self, block, parameters):
         for p in parameters:
             self._add_accumulator(self._momentum_acc_str, p)
             self._add_accumulator(self._mean_square_acc_str, p)
+            self._add_accumulator(self._mean_grad_acc_str, p)
 
     def _append_optimize_op(self, block, param_and_grad):
         if not isinstance(block, framework.Block):
@@ -967,24 +988,29 @@ def _append_optimize_op(self, block, param_and_grad):
                                             param_and_grad[0])
         mean_square_acc = self._get_accumulator(self._mean_square_acc_str,
                                                 param_and_grad[0])
+        mean_grad_acc = self._get_accumulator(self._mean_grad_acc_str,
+                                              param_and_grad[0])
         rmsprop_op = block.append_op(
             type=self.type,
             inputs={
                 "Param": param_and_grad[0],
                 "Grad": param_and_grad[1],
                 "Moment": momentum_acc,
                 "MeanSquare": mean_square_acc,
+                "MeanGrad": mean_grad_acc,
                 "LearningRate": self._create_param_lr(param_and_grad),
             },
             outputs={
                 "ParamOut": param_and_grad[0],
                 "MomentOut": momentum_acc,
-                "MeanSquareOut": mean_square_acc
+                "MeanSquareOut": mean_square_acc,
+                "MeanGradOut": mean_grad_acc
             },
             attrs={
                 "epsilon": self._epsilon,
                 "decay": self._rho,
-                "momentum": self._momentum
+                "momentum": self._momentum,
+                "centered": self._centered
             })
 
         return rmsprop_op
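A minimal usage sketch of the new flag, assuming a typical fluid program in which a loss variable has already been built (illustrative only):

import paddle.fluid as fluid

# `loss` is assumed to be a variable defined earlier in the program.
optimizer = fluid.optimizer.RMSPropOptimizer(
    learning_rate=0.01,
    rho=0.95,
    epsilon=1e-6,
    momentum=0.0,
    centered=True)  # normalize by the estimated gradient variance
optimizer.minimize(loss)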

python/paddle/fluid/tests/unittests/op_test.py

Lines changed: 5 additions & 5 deletions

@@ -291,7 +291,7 @@ def _calc_output(self, place, parallel=False):
             return_numpy=False)
         return outs, fetch_list
 
-    def check_output_with_place(self, place, atol):
+    def check_output_with_place(self, place, atol, equal_nan=False):
         outs, fetch_list = self._calc_output(place)
         for out_name, out_dup in Operator.get_op_outputs(self.op_type):
             if out_name not in self.outputs:
@@ -321,7 +321,7 @@ def find_actual(target_name, fetch_list):
                     if isinstance(expect, tuple) else expect
                 self.assertTrue(
                     np.allclose(
-                        actual_t, expect_t, atol=atol),
+                        actual_t, expect_t, atol=atol, equal_nan=equal_nan),
                     "Output (" + sub_out_name + ") has diff at " +
                     str(place))
             if isinstance(expect, tuple):
@@ -337,7 +337,7 @@ def find_actual(target_name, fetch_list):
                 expect_t = expect[0] if isinstance(expect, tuple) else expect
                 self.assertTrue(
                     np.allclose(
-                        actual_t, expect_t, atol=atol),
+                        actual_t, expect_t, atol=atol, equal_nan=equal_nan),
                     "Output (" + out_name + ") has diff at " + str(place) +
                     "\nExpect " + str(expect_t) + "\n" + "But Got" +
                     str(actual_t))
@@ -360,10 +360,10 @@ def _get_places(self):
             places.append(core.CUDAPlace(0))
         return places
 
-    def check_output(self, atol=1e-5):
+    def check_output(self, atol=1e-5, equal_nan=False):
        places = self._get_places()
         for place in places:
-            self.check_output_with_place(place, atol)
+            self.check_output_with_place(place, atol, equal_nan)
 
     def check_output_customized(self, checker):
         places = self._get_places()
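The new equal_nan flag is forwarded straight to np.allclose, where it makes NaNs in matching positions compare as equal; a quick standalone demonstration:

import numpy as np

a = np.array([1.0, np.nan])
b = np.array([1.0, np.nan])
print(np.allclose(a, b))                  # False: NaN != NaN by default
print(np.allclose(a, b, equal_nan=True))  # True: matching NaNs accepted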
