11package aima .learning .neural ;
22
3- import java .util .ArrayList ;
4- import java .util .List ;
5-
63import aima .learning .statistics .ActivationFunction ;
74import aima .util .Matrix ;
85import aima .util .Util ;
@@ -17,7 +14,7 @@ public class Layer {
1714
1815 private Vector lastActivationValues , lastInducedField ;
1916
20- private Matrix lastSensitivityMatrix ;
17+ private Matrix mySensitivityMatrix ;
2118
2219 private Matrix lastWeightUpdateMatrix ;
2320
@@ -35,8 +32,8 @@ public Layer(Matrix weightMatrix, Vector biasVector, ActivationFunction af) {
3532 weightMatrix .getColumnDimension ());
3633 penultimateWeightUpdateMatrix = new Matrix (weightMatrix
3734 .getRowDimension (), weightMatrix .getColumnDimension ());
38- // sensitivityMatrix = new Matrix(weightMatrix.getRowDimension(),
39- // weightMatrix.getColumnDimension());
35+ mySensitivityMatrix = new Matrix (weightMatrix .getRowDimension (),
36+ weightMatrix .getColumnDimension ());
4037 this .biasVector = biasVector ;
4138 lastBiasUpdateVector = new Vector (biasVector .getRowDimension ());
4239 penultimateBiasUpdateVector = new Vector (biasVector .getRowDimension ());
@@ -91,26 +88,9 @@ public Vector getBiasVector() {
9188 return biasVector ;
9289 }
9390
94- public Matrix sensitivityMatrixFromErrorMatrix (Vector errorVector ) {
95- Matrix derivativeMatrix = createDerivativeMatrix (lastInducedField );
96- Matrix sensitivityMatrix = derivativeMatrix .times (errorVector ).times (
97- -2.0 );
98- lastSensitivityMatrix = sensitivityMatrix .copy ();
99- return sensitivityMatrix ;
100- }
101-
102- public Matrix sensitivityMatrixFromSucceedingLayer (Layer nextLayer ) {
103- Matrix derivativeMatrix = createDerivativeMatrix (lastInducedField );
104- Matrix weightTranspose = nextLayer .weightMatrix .transpose ();
105- Matrix sensitivityMatrix = derivativeMatrix .times (weightTranspose )
106- .times (nextLayer .getSensitivityMatrix ());
107- lastSensitivityMatrix = sensitivityMatrix .copy ();
108- return sensitivityMatrix ;
109- }
110-
111- private Matrix getSensitivityMatrix () {
91+ public Matrix getSensitivityMatrix () {
11292
113- return lastSensitivityMatrix ;
93+ return mySensitivityMatrix ;
11494 }
11595
11696 public int numberOfNeurons () {
@@ -151,88 +131,38 @@ private static void initializeVector(Vector aVector, double lowerLimit,
151131 }
152132 }
153133
154- private Matrix createDerivativeMatrix (Vector lastInducedField ) {
155- List <Double > lst = new ArrayList <Double >();
156- for (int i = 0 ; i < lastInducedField .size (); i ++) {
157- lst .add (new Double (activationFunction .deriv (lastInducedField
158- .getValue (i ))));
159- }
160- return Matrix .createDiagonalMatrix (lst );
161- }
162-
163- public Matrix calculateWeightUpdates (Vector previousLayerActivationOrInput ,
164- double alpha ) {
165- Matrix activationTranspose = previousLayerActivationOrInput .transpose ();
166- Matrix weightUpdateMatrix = lastSensitivityMatrix .times (
167- activationTranspose ).times (alpha ).times (-1.0 );
168- penultimateWeightUpdateMatrix = lastWeightUpdateMatrix .copy ();
169- lastWeightUpdateMatrix = weightUpdateMatrix .copy ();
170- return weightUpdateMatrix ;
171- }
172-
173- public Matrix calculateWeightUpdates (Vector previousLayerActivationOrInput ,
174- double alpha , double momentum ) {
175- Matrix activationTranspose = previousLayerActivationOrInput .transpose ();
176- Matrix momentumLessUpdate = lastSensitivityMatrix .times (
177- activationTranspose ).times (alpha ).times (-1.0 );
178- Matrix updateWithMomentum = lastWeightUpdateMatrix .times (momentum )
179- .plus (momentumLessUpdate .times (1.0 - momentum ));
180- penultimateWeightUpdateMatrix = lastWeightUpdateMatrix .copy (); // done
181- // only
182- // to
183- // implement
184- // VLBP
185- // later
186- lastWeightUpdateMatrix = updateWithMomentum .copy ();
187- return updateWithMomentum ;
188- }
189-
190134 public Matrix getLastWeightUpdateMatrix () {
191135 return lastWeightUpdateMatrix ;
192136 }
193137
138+ public void setLastWeightUpdateMatrix (Matrix m ) {
139+ lastWeightUpdateMatrix = m ;
140+ }
141+
194142 public Matrix getPenultimateWeightUpdateMatrix () {
195143 return penultimateWeightUpdateMatrix ;
196144 }
197145
198- public Vector calculateBiasUpdates (double alpha ) {
199- Matrix biasUpdateMatrix = lastSensitivityMatrix .times (alpha )
200- .times (-1.0 );
201-
202- Vector result = new Vector (biasUpdateMatrix .getRowDimension ());
203- for (int i = 0 ; i < biasUpdateMatrix .getRowDimension (); i ++) {
204- result .setValue (i , biasUpdateMatrix .get (i , 0 ));
205- }
206- penultimateBiasUpdateVector = lastBiasUpdateVector .copyVector ();
207- lastBiasUpdateVector = result .copyVector ();
208- return result ;
209- }
210-
211- public Vector calculateBiasUpdates (double alpha , double momentum ) {
212- Matrix biasUpdateMatrixWithoutMomentum = lastSensitivityMatrix .times (
213- alpha ).times (-1.0 );
214- ;
215- Matrix biasUpdateMatrixWithMomentum = lastBiasUpdateVector .times (
216- momentum ).plus (
217- biasUpdateMatrixWithoutMomentum .times (1.0 - momentum ));
218- Vector result = new Vector (biasUpdateMatrixWithMomentum
219- .getRowDimension ());
220- for (int i = 0 ; i < biasUpdateMatrixWithMomentum .getRowDimension (); i ++) {
221- result .setValue (i , biasUpdateMatrixWithMomentum .get (i , 0 ));
222- }
223- penultimateBiasUpdateVector = lastBiasUpdateVector .copyVector ();
224- lastBiasUpdateVector = result .copyVector ();
225- return result ;
146+ public void setPenultimateWeightUpdateMatrix (Matrix m ) {
147+ penultimateWeightUpdateMatrix = m ;
226148 }
227149
228150 public Vector getLastBiasUpdateVector () {
229151 return lastBiasUpdateVector ;
230152 }
231153
154+ public void setLastBiasUpdateVector (Vector v ) {
155+ lastBiasUpdateVector = v ;
156+ }
157+
232158 public Vector getPenultimateBiasUpdateVector () {
233159 return penultimateBiasUpdateVector ;
234160 }
235161
162+ public void setPenultimateBiasUpdateVector (Vector v ) {
163+ penultimateBiasUpdateVector = v ;
164+ }
165+
236166 public void updateWeights () {
237167 weightMatrix .plusEquals (lastWeightUpdateMatrix );
238168 }
@@ -251,4 +181,13 @@ public Vector getLastInputValues() {
251181 return lastInput ;
252182
253183 }
184+
185+ public ActivationFunction getActivationFunction () {
186+
187+ return activationFunction ;
188+ }
189+
190+ public void setSensitivityMatrix (Matrix sensitivityMatrix ) {
191+ this .mySensitivityMatrix = sensitivityMatrix ;
192+ }
254193}
0 commit comments