
Commit c36bf19

add optimizer doc
1 parent 929090e commit c36bf19

2 files changed: +131 -2 lines changed

doc/api/v2/config/optimizer.rst

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-.. _api_v2.optimizer:
-
 ==========
 Optimizer
 ==========

python/paddle/v2/optimizer.py

Lines changed: 131 additions & 0 deletions
@@ -47,6 +47,35 @@ def create_remote_updater(self, pass_num):
 
 
 class Momentum(Optimizer):
+    """
+    SGD Optimizer.
+
+    SGD is an optimization method that iteratively updates the weights of a
+    neural network to minimize its cost/error. In Paddle's implementation the
+    SGD optimizer is synchronized: all gradients are computed and reduced into
+    one gradient before the optimization step is applied.
+
+    The learning problem is to minimize an objective function that has the
+    form of a sum
+
+    .. math::
+
+        Q(w) = \\sum_{i}^{n} Q_i(w)
+
+    The value of Q is typically the cost of the neural network (for example,
+    the mean squared error between prediction and label). Q is parametrised by
+    w, the weights/biases of the network, which are what is to be learned, and
+    i indexes the i-th observation in the training data.
+
+    SGD then optimizes the weights by
+
+    .. math::
+
+        w = w - \\eta \\nabla Q(w) = w - \\eta \\sum_{i}^{n} \\nabla Q_i(w)
+
+    where :math:`\\eta` is the learning rate and :math:`n` is the batch size.
+    """
+
     def __init__(self, momentum=None, sparse=False, **kwargs):
         learning_method = v1_optimizers.MomentumOptimizer(
             momentum=momentum, sparse=sparse)
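
As an aside, the update rule documented above can be illustrated with a minimal NumPy sketch of one synchronized SGD step. This is an illustration of the formula only, not Paddle's implementation, and the gradient function grad_Q is a hypothetical stand-in::

    import numpy as np

    def sgd_step(w, grad_Q, batch, eta=0.01):
        """One synchronized SGD step: accumulate (reduce) the per-example
        gradients, then update w <- w - eta * sum_i grad Q_i(w)."""
        g = np.zeros_like(w)
        for x_i, y_i in batch:        # each observation i in the batch
            g += grad_Q(w, x_i, y_i)  # gradient of Q_i at the current weights
        return w - eta * g            # gradient descent update
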
@@ -55,26 +84,92 @@ def __init__(self, momentum=None, sparse=False, **kwargs):
 
 
 class Adam(Optimizer):
+    """
+    Adam optimizer.
+
+    For details, please refer to `Adam: A Method for Stochastic Optimization
+    <https://arxiv.org/abs/1412.6980>`_.
+
+    .. math::
+
+        m(w, t) & = \\beta_1 m(w, t-1) + (1 - \\beta_1) \\nabla Q_i(w) \\\\
+        v(w, t) & = \\beta_2 v(w, t-1) + (1 - \\beta_2)(\\nabla Q_i(w)) ^2 \\\\
+        w & = w - \\frac{\\eta m(w, t)}{\\sqrt{v(w,t) + \\epsilon}}
+
+    :param beta1: the :math:`\\beta_1` in the equation.
+    :type beta1: float
+    :param beta2: the :math:`\\beta_2` in the equation.
+    :type beta2: float
+    :param epsilon: the :math:`\\epsilon` in the equation. It is used to
+                    prevent division by zero.
+    :type epsilon: float
+    """
+
     def __init__(self, beta1=0.9, beta2=0.999, epsilon=1e-8, **kwargs):
         learning_method = v1_optimizers.AdamOptimizer(
             beta1=beta1, beta2=beta2, epsilon=epsilon)
         super(Adam, self).__init__(learning_method=learning_method, **kwargs)
 
 
 class Adamax(Optimizer):
+    """
+    Adamax optimizer.
+
+    For details, please refer to `Adam: A Method for Stochastic Optimization
+    <https://arxiv.org/abs/1412.6980>`_.
+
+    .. math::
+
+        m_t & = \\beta_1 * m_{t-1} + (1-\\beta_1)* \\nabla Q_i(w) \\\\
+        u_t & = max(\\beta_2*u_{t-1}, abs(\\nabla Q_i(w))) \\\\
+        w_t & = w_{t-1} - (\\eta/(1-\\beta_1^t))*m_t/u_t
+
+    :param beta1: the :math:`\\beta_1` in the equation.
+    :type beta1: float
+    :param beta2: the :math:`\\beta_2` in the equation.
+    :type beta2: float
+    """
+
     def __init__(self, beta1=0.9, beta2=0.999, **kwargs):
         learning_method = v1_optimizers.AdamaxOptimizer(
             beta1=beta1, beta2=beta2)
         super(Adamax, self).__init__(learning_method=learning_method, **kwargs)
 
 
 class AdaGrad(Optimizer):
+    """
+    AdaGrad (for ADAptive GRAdient algorithm) optimizer.
+
+    For details, please refer to `Adaptive Subgradient Methods for
+    Online Learning and Stochastic Optimization
+    <http://www.magicbroom.info/Papers/DuchiHaSi10.pdf>`_.
+
+    .. math::
+
+        G &= \\sum_{\\tau=1}^{t} g_{\\tau} g_{\\tau}^T \\\\
+        w & = w - \\eta diag(G)^{-\\frac{1}{2}} \\circ g
+    """
+
     def __init__(self, **kwargs):
         learning_method = v1_optimizers.AdaGradOptimizer()
         super(AdaGrad, self).__init__(learning_method=learning_method, **kwargs)
 
 
 class DecayedAdaGrad(Optimizer):
+    """
+    AdaGrad method with an exponentially decayed sum of squared gradients.
+    The equations of this method are as follows.
+
+    .. math::
+
+        E(g_t^2) &= \\rho * E(g_{t-1}^2) + (1-\\rho) * g^2 \\\\
+        learning\\_rate &= 1/sqrt( E(g_t^2) + \\epsilon )
+
+    :param rho: the :math:`\\rho` parameter in the equation.
+    :type rho: float
+    :param epsilon: the :math:`\\epsilon` parameter in the equation.
+    :type epsilon: float
+    """
+
     def __init__(self, rho=0.95, epsilon=1e-06, **kwargs):
         learning_method = v1_optimizers.DecayedAdaGradOptimizer(
             rho=rho, epsilon=epsilon)
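
For the Adam equations above, a minimal NumPy sketch of one update (without the bias correction used in the paper, matching the docstring's simplified form) could look like the following; it is an illustration only, not the v1 AdamOptimizer implementation::

    import numpy as np

    def adam_step(w, grad, m, v, eta=1e-3, beta1=0.9, beta2=0.999,
                  epsilon=1e-8):
        """One Adam step per the docstring: update the first and second
        moment estimates, then apply the scaled gradient update."""
        m = beta1 * m + (1 - beta1) * grad       # m(w, t)
        v = beta2 * v + (1 - beta2) * grad ** 2  # v(w, t)
        w = w - eta * m / np.sqrt(v + epsilon)   # parameter update
        return w, m, v
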
@@ -83,6 +178,24 @@ def __init__(self, rho=0.95, epsilon=1e-06, **kwargs):
 
 
 class AdaDelta(Optimizer):
+    """
+    AdaDelta method. For details of AdaDelta, please refer to
+    `ADADELTA: AN ADAPTIVE LEARNING RATE METHOD
+    <http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf>`_.
+
+    .. math::
+
+        E(g_t^2) &= \\rho * E(g_{t-1}^2) + (1-\\rho) * g^2 \\\\
+        learning\\_rate &= sqrt( ( E(dx_{t-1}^2) + \\epsilon ) / ( \\
+                          E(g_t^2) + \\epsilon ) ) \\\\
+        E(dx_t^2) &= \\rho * E(dx_{t-1}^2) + (1-\\rho) * (-g*learning\\_rate)^2
+
+    :param rho: :math:`\\rho` in the equation.
+    :type rho: float
+    :param epsilon: :math:`\\epsilon` in the equation.
+    :type epsilon: float
+    """
+
     def __init__(self, rho=0.95, epsilon=1e-06, **kwargs):
         learning_method = v1_optimizers.AdaDeltaOptimizer(
             rho=rho, epsilon=epsilon)
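
Likewise, the AdaDelta equations can be sketched in NumPy as follows; again this is a hedged illustration of the docstring's formulas, not the v1 AdaDeltaOptimizer implementation::

    import numpy as np

    def adadelta_step(w, grad, eg2, edx2, rho=0.95, epsilon=1e-6):
        """One AdaDelta step per the docstring: decayed averages of squared
        gradients and squared updates yield a per-weight learning rate."""
        eg2 = rho * eg2 + (1 - rho) * grad ** 2             # E(g_t^2)
        rate = np.sqrt((edx2 + epsilon) / (eg2 + epsilon))  # learning_rate
        dx = -grad * rate                                   # update step
        edx2 = rho * edx2 + (1 - rho) * dx ** 2             # E(dx_t^2)
        return w + dx, eg2, edx2
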
@@ -91,6 +204,24 @@ def __init__(self, rho=0.95, epsilon=1e-06, **kwargs):
 
 
 class RMSProp(Optimizer):
+    """
+    RMSProp (for Root Mean Square Propagation) optimizer. For details, please
+    refer to this `slide <http://www.cs.toronto.edu/~tijmen/csc321/slides/
+    lecture_slides_lec6.pdf>`_.
+
+    The equations of this method are as follows:
+
+    .. math::
+
+        v(w, t) & = \\rho v(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2 \\\\
+        w & = w - \\frac{\\eta}{\\sqrt{v(w,t) + \\epsilon}} \\nabla Q_{i}(w)
+
+    :param rho: the :math:`\\rho` in the equation, the forgetting factor.
+    :type rho: float
+    :param epsilon: the :math:`\\epsilon` in the equation.
+    :type epsilon: float
+    """
+
     def __init__(self, rho=0.95, epsilon=1e-6, **kwargs):
         learning_method = v1_optimizers.RMSPropOptimizer(
             rho=rho, epsilon=epsilon)
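
Finally, a possible usage sketch for the classes documented in this commit; note that the learning_rate keyword (assumed to be forwarded through **kwargs to the base Optimizer) is not shown in this diff and is an assumption::

    import paddle.v2 as paddle

    # Construct an optimizer to pass to the trainer; the keyword arguments
    # mirror the docstring parameters above.
    adam = paddle.optimizer.Adam(
        learning_rate=1e-3,  # assumed **kwargs setting, not part of this diff
        beta1=0.9,           # decay rate of the first moment estimate
        beta2=0.999,         # decay rate of the second moment estimate
        epsilon=1e-8)        # prevents division by zero

    rmsprop = paddle.optimizer.RMSProp(rho=0.95, epsilon=1e-6)
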
