@@ -96,6 +96,8 @@ class AdaBoundUpdate
9696 class Policy
9797 {
9898 public:
99+ typedef typename MatType::elem_type ElemType;
100+
99101 /* *
100102 * This constructor is called by the SGD Optimize() method before the start
101103 * of the iteration update process.
@@ -105,10 +107,24 @@ class AdaBoundUpdate
105107 * @param cols Number of columns in the gradient matrix.
106108 */
107109 Policy (AdaBoundUpdate& parent, const size_t rows, const size_t cols) :
108- parent (parent), first(true ), initialStepSize(0 ), iteration(0 )
110+ parent (parent),
111+ finalLr (ElemType(parent.finalLr)),
112+ gamma (ElemType(parent.gamma)),
113+ epsilon (ElemType(parent.epsilon)),
114+ beta1 (ElemType(parent.beta1)),
115+ beta2 (ElemType(parent.beta2)),
116+ first (true ),
117+ initialStepSize (0 ),
118+ iteration (0 )
109119 {
110120 m.zeros (rows, cols);
111121 v.zeros (rows, cols);
122+
123+ // Check for underflows in conversions.
124+ if (gamma == ElemType (0 ) && parent.gamma != 0.0 )
125+ gamma = 10 * std::numeric_limits<ElemType>::epsilon ();
126+ if (epsilon == ElemType (0 ) && parent.epsilon != 0.0 )
127+ epsilon = 10 * std::numeric_limits<ElemType>::epsilon ();
112128 }
113129
114130 /* *
@@ -129,30 +145,30 @@ class AdaBoundUpdate
129145 if (first)
130146 {
131147 first = false ;
132- initialStepSize = stepSize;
148+ initialStepSize = ElemType ( stepSize) ;
133149 }
134150
135151 // Increment the iteration counter variable.
136152 ++iteration;
137153
138154 // Decay the first and second moment running average coefficient.
139- m *= parent. beta1 ;
140- m += (1 - parent. beta1 ) * gradient;
155+ m *= beta1;
156+ m += (1 - beta1) * gradient;
141157
142- v *= parent. beta2 ;
143- v += (1 - parent. beta2 ) * (gradient % gradient);
158+ v *= beta2;
159+ v += (1 - beta2) * (gradient % gradient);
144160
145- const ElemType biasCorrection1 = 1.0 - std::pow (parent. beta1 , iteration);
146- const ElemType biasCorrection2 = 1.0 - std::pow (parent. beta2 , iteration);
161+ const ElemType biasCorrection1 = 1 - std::pow (beta1, ElemType ( iteration) );
162+ const ElemType biasCorrection2 = 1 - std::pow (beta2, ElemType ( iteration) );
147163
148- const ElemType fl = parent. finalLr * stepSize / initialStepSize;
149- const ElemType lower = fl * (1.0 - 1.0 / (parent. gamma * iteration + 1 ));
150- const ElemType upper = fl * (1.0 + 1.0 / (parent. gamma * iteration));
164+ const ElemType fl = finalLr * ElemType ( stepSize) / initialStepSize;
165+ const ElemType lower = fl * (1 - 1 / (gamma * iteration + 1 ));
166+ const ElemType upper = fl * (1 + 1 / (gamma * iteration));
151167
152- // Applies bounds on actual learning rate.
153- iterate -= arma:: clamp ((stepSize *
154- std::sqrt (biasCorrection2) / biasCorrection1) / (arma:: sqrt (v) +
155- parent. epsilon ), lower, upper) % m;
168+ // Applies bounds on actual learning rate.
169+ iterate -= clamp ((ElemType ( stepSize) *
170+ std::sqrt (biasCorrection2) / biasCorrection1) / (sqrt (v) + epsilon),
171+ lower, upper) % m;
156172 }
157173
158174 private:
@@ -165,11 +181,18 @@ class AdaBoundUpdate
165181 // The exponential moving average of squared gradient values.
166182 GradType v;
167183
184+ // Parameters of the parent, casted to the element type of the problem.
185+ ElemType finalLr;
186+ ElemType gamma;
187+ ElemType epsilon;
188+ ElemType beta1;
189+ ElemType beta2;
190+
168191 // Whether this is the first call of the Update method.
169192 bool first;
170193
171194 // The initial (Adam) learning rate.
172- double initialStepSize;
195+ ElemType initialStepSize;
173196
174197 // The number of iterations.
175198 size_t iteration;
0 commit comments