@@ -51,49 +51,6 @@ class AdaDeltaUpdate
5151 // Nothing to do.
5252 }
5353
54- /* *
55- * The Initialize method is called by SGD Optimizer method before the start of
56- * the iteration update process. In AdaDelta update policy, the mean squared
57- * and the delta mean squared gradient matrices are initialized to the zeros
58- * matrix with the same size as gradient matrix (see ens::SGD<>).
59- *
60- * @param rows Number of rows in the gradient matrix.
61- * @param cols Number of columns in the gradient matrix.
62- */
63- void Initialize (const size_t rows, const size_t cols)
64- {
65- // Initialize empty matrices for mean sum of squares of parameter gradient.
66- meanSquaredGradient = arma::zeros<arma::mat>(rows, cols);
67- meanSquaredGradientDx = arma::zeros<arma::mat>(rows, cols);
68- }
69-
70- /* *
71- * Update step for SGD. The AdaDelta update dynamically adapts over time using
72- * only first order information. Additionally, AdaDelta requires no manual
73- * tuning of a learning rate.
74- *
75- * @param iterate Parameters that minimize the function.
76- * @param stepSize Step size to be used for the given iteration.
77- * @param gradient The gradient matrix.
78- */
79- void Update (arma::mat& iterate,
80- const double stepSize,
81- const arma::mat& gradient)
82- {
83- // Accumulate gradient.
84- meanSquaredGradient *= rho;
85- meanSquaredGradient += (1 - rho) * (gradient % gradient);
86- arma::mat dx = arma::sqrt ((meanSquaredGradientDx + epsilon) /
87- (meanSquaredGradient + epsilon)) % gradient;
88-
89- // Accumulate updates.
90- meanSquaredGradientDx *= rho;
91- meanSquaredGradientDx += (1 - rho) * (dx % dx);
92-
93- // Apply update.
94- iterate -= (stepSize * dx);
95- }
96-
9754 // ! Get the smoothing parameter.
9855 double Rho () const { return rho; }
  //! Modify the smoothing parameter.
@@ -104,18 +61,77 @@ class AdaDeltaUpdate
10461 // ! Modify the value used to initialise the mean squared gradient parameter.
10562 double & Epsilon () { return epsilon; }
10663
64+ /* *
65+ * The UpdatePolicyType policy classes must contain an internal 'Policy'
66+ * template class with two template arguments: MatType and GradType. This is
67+ * instantiated at the start of the optimization, and holds parameters
68+ * specific to an individual optimization.
69+ */
70+ template <typename MatType, typename GradType>
71+ class Policy
72+ {
73+ public:
74+ /* *
75+ * This constructor is called by the SGD optimizer method before the start
76+ * of the iteration update process. In AdaDelta update policy, the mean
77+ * squared and the delta mean squared gradient matrices are initialized to
78+ * the zeros matrix with the same size as gradient matrix (see ens::SGD<>).
79+ *
80+ * @param parent AdaDeltaUpdate object.
81+ * @param rows Number of rows in the gradient matrix.
82+ * @param cols Number of columns in the gradient matrix.
83+ */
84+ Policy (AdaDeltaUpdate& parent, const size_t rows, const size_t cols) :
85+ parent (parent)
86+ {
87+ meanSquaredGradient.zeros (rows, cols);
88+ meanSquaredGradientDx.zeros (rows, cols);
89+ }
90+
91+ /* *
92+ * Update step for SGD. The AdaDelta update dynamically adapts over time
93+ * using only first order information. Additionally, AdaDelta requires no
94+ * manual tuning of a learning rate.
95+ *
96+ * @param iterate Parameters that minimize the function.
97+ * @param stepSize Step size to be used for the given iteration.
98+ * @param gradient The gradient matrix.
99+ */
100+ void Update (MatType& iterate,
101+ const double stepSize,
102+ const GradType& gradient)
103+ {
104+ // Accumulate gradient.
105+ meanSquaredGradient *= parent.rho ;
106+ meanSquaredGradient += (1 - parent.rho ) * (gradient % gradient);
107+ GradType dx = arma::sqrt ((meanSquaredGradientDx + parent.epsilon ) /
108+ (meanSquaredGradient + parent.epsilon )) % gradient;
109+
110+ // Accumulate updates.
111+ meanSquaredGradientDx *= parent.rho ;
112+ meanSquaredGradientDx += (1 - parent.rho ) * (dx % dx);
113+
114+ // Apply update.
115+ iterate -= (stepSize * dx);
116+ }
117+
118+ private:
119+ // The instantiated parent class.
120+ AdaDeltaUpdate& parent;
121+
122+ // The mean squared gradient matrix.
123+ GradType meanSquaredGradient;
124+
125+ // The delta mean squared gradient matrix.
126+ GradType meanSquaredGradientDx;
127+ };
128+
107129 private:
108130 // The smoothing parameter.
109131 double rho;
110132
111133 // The epsilon value used to initialise the mean squared gradient parameter.
112134 double epsilon;
113-
114- // The mean squared gradient matrix.
115- arma::mat meanSquaredGradient;
116-
117- // The delta mean squared gradient matrix.
118- arma::mat meanSquaredGradientDx;
119135};
120136
121137} // namespace ens
0 commit comments