Skip to content

Commit c80e3c2

Browse files
author
magicindian
committed
all weight and bias updates moved out of Layer class
1 parent 84d849a commit c80e3c2

File tree

5 files changed

+220
-124
lines changed

5 files changed

+220
-124
lines changed
src/aima/learning/neural/BackPropLearning.java

Lines changed: 85 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
package aima.learning.neural;
22

3+
import aima.util.Matrix;
4+
35
public class BackPropLearning implements NNTrainingScheme {
46
private final double learningRate;
57
private final double momentum;
68

79
private final Layer hiddenLayer;
810
private final Layer outputLayer;
11+
private final LayerSensitivity hiddenSensitivity;
12+
private final LayerSensitivity outputSensitivity;
913

1014
public BackPropLearning(FeedForwardNeuralNetwork network,
1115
double learningRate, double momentum) {
1216

1317
this.hiddenLayer = network.getHiddenLayer();
1418
this.outputLayer = network.getOutputLayer();
19+
hiddenSensitivity = new LayerSensitivity(hiddenLayer);
20+
outputSensitivity = new LayerSensitivity(outputLayer);
1521
this.learningRate = learningRate;
1622
this.momentum = momentum;
23+
1724
}
1825

1926
public Vector processInput(FeedForwardNeuralNetwork network, Vector input) {
@@ -26,19 +33,21 @@ public Vector processInput(FeedForwardNeuralNetwork network, Vector input) {
2633
public void processError(FeedForwardNeuralNetwork network, Vector error) {
2734
// TODO calculate total error somewhere
2835
// create Sensitivity Matrices
29-
outputLayer.sensitivityMatrixFromErrorMatrix(error);
30-
hiddenLayer.sensitivityMatrixFromSucceedingLayer(network
31-
.getOutputLayer());
36+
outputLayer.setSensitivityMatrix(outputSensitivity
37+
.sensitivityMatrixFromErrorMatrix(error));
38+
39+
hiddenLayer.setSensitivityMatrix(hiddenSensitivity
40+
.sensitivityMatrixFromSucceedingLayer(outputLayer));
3241

3342
// calculate weight Updates
34-
outputLayer.calculateWeightUpdates(hiddenLayer
43+
calculateWeightUpdates(outputSensitivity, hiddenLayer
3544
.getLastActivationValues(), learningRate, momentum);
36-
hiddenLayer.calculateWeightUpdates(hiddenLayer.getLastInputValues(),
37-
learningRate, momentum);
45+
calculateWeightUpdates(hiddenSensitivity, hiddenLayer
46+
.getLastInputValues(), learningRate, momentum);
3847

3948
// calculate Bias Updates
40-
outputLayer.calculateBiasUpdates(learningRate, momentum);
41-
hiddenLayer.calculateBiasUpdates(learningRate, momentum);
49+
calculateBiasUpdates(outputSensitivity, learningRate, momentum);
50+
calculateBiasUpdates(hiddenSensitivity, learningRate, momentum);
4251

4352
// update weightsAndBiases
4453
outputLayer.updateWeights();
@@ -48,4 +57,72 @@ public void processError(FeedForwardNeuralNetwork network, Vector error) {
4857
hiddenLayer.updateBiases();
4958

5059
}
60+
61+
public Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity,
62+
Vector previousLayerActivationOrInput, double alpha, double momentum) {
63+
Layer layer = layerSensitivity.getLayer();
64+
Matrix activationTranspose = previousLayerActivationOrInput.transpose();
65+
Matrix momentumLessUpdate = layer.getSensitivityMatrix().times(
66+
activationTranspose).times(alpha).times(-1.0);
67+
Matrix updateWithMomentum = layer.getLastWeightUpdateMatrix().times(
68+
momentum).plus(momentumLessUpdate.times(1.0 - momentum));
69+
layer.setPenultimateWeightUpdateMatrix(layer
70+
.getLastWeightUpdateMatrix().copy()); // done
71+
// only
72+
// to
73+
// implement
74+
// VLBP
75+
// later
76+
layer.setLastWeightUpdateMatrix(updateWithMomentum.copy());
77+
return updateWithMomentum;
78+
}
79+
80+
public static Matrix calculateWeightUpdates(
81+
LayerSensitivity layerSensitivity,
82+
Vector previousLayerActivationOrInput, double alpha) {
83+
Layer layer = layerSensitivity.getLayer();
84+
Matrix activationTranspose = previousLayerActivationOrInput.transpose();
85+
Matrix weightUpdateMatrix = layerSensitivity.getSensitivityMatrix()
86+
.times(activationTranspose).times(alpha).times(-1.0);
87+
layer.setPenultimateWeightUpdateMatrix(layer
88+
.getLastWeightUpdateMatrix().copy());
89+
layer.setLastWeightUpdateMatrix(weightUpdateMatrix.copy());
90+
return weightUpdateMatrix;
91+
}
92+
93+
public Vector calculateBiasUpdates(LayerSensitivity layerSensitivity,
94+
double alpha, double momentum) {
95+
Layer layer = layerSensitivity.getLayer();
96+
Matrix biasUpdateMatrixWithoutMomentum = layer.getSensitivityMatrix()
97+
.times(alpha).times(-1.0);
98+
99+
Matrix biasUpdateMatrixWithMomentum = layer.getLastBiasUpdateVector()
100+
.times(momentum).plus(
101+
biasUpdateMatrixWithoutMomentum.times(1.0 - momentum));
102+
Vector result = new Vector(biasUpdateMatrixWithMomentum
103+
.getRowDimension());
104+
for (int i = 0; i < biasUpdateMatrixWithMomentum.getRowDimension(); i++) {
105+
result.setValue(i, biasUpdateMatrixWithMomentum.get(i, 0));
106+
}
107+
layer.setPenultimateBiasUpdateVector(layer.getLastBiasUpdateVector()
108+
.copyVector());
109+
layer.setLastBiasUpdateVector(result.copyVector());
110+
return result;
111+
}
112+
113+
public static Vector calculateBiasUpdates(
114+
LayerSensitivity layerSensitivity, double alpha) {
115+
Layer layer = layerSensitivity.getLayer();
116+
Matrix biasUpdateMatrix = layer.getSensitivityMatrix().times(alpha)
117+
.times(-1.0);
118+
119+
Vector result = new Vector(biasUpdateMatrix.getRowDimension());
120+
for (int i = 0; i < biasUpdateMatrix.getRowDimension(); i++) {
121+
result.setValue(i, biasUpdateMatrix.get(i, 0));
122+
}
123+
layer.setPenultimateBiasUpdateVector(layer.getLastBiasUpdateVector()
124+
.copyVector());
125+
layer.setLastBiasUpdateVector(result.copyVector());
126+
return result;
127+
}
51128
}

src/aima/learning/neural/Layer.java

Lines changed: 28 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
package aima.learning.neural;
22

3-
import java.util.ArrayList;
4-
import java.util.List;
5-
63
import aima.learning.statistics.ActivationFunction;
74
import aima.util.Matrix;
85
import aima.util.Util;
@@ -17,7 +14,7 @@ public class Layer {
1714

1815
private Vector lastActivationValues, lastInducedField;
1916

20-
private Matrix lastSensitivityMatrix;
17+
private Matrix mySensitivityMatrix;
2118

2219
private Matrix lastWeightUpdateMatrix;
2320

@@ -35,8 +32,8 @@ public Layer(Matrix weightMatrix, Vector biasVector, ActivationFunction af) {
3532
weightMatrix.getColumnDimension());
3633
penultimateWeightUpdateMatrix = new Matrix(weightMatrix
3734
.getRowDimension(), weightMatrix.getColumnDimension());
38-
// sensitivityMatrix = new Matrix(weightMatrix.getRowDimension(),
39-
// weightMatrix.getColumnDimension());
35+
mySensitivityMatrix = new Matrix(weightMatrix.getRowDimension(),
36+
weightMatrix.getColumnDimension());
4037
this.biasVector = biasVector;
4138
lastBiasUpdateVector = new Vector(biasVector.getRowDimension());
4239
penultimateBiasUpdateVector = new Vector(biasVector.getRowDimension());
@@ -91,26 +88,9 @@ public Vector getBiasVector() {
9188
return biasVector;
9289
}
9390

94-
public Matrix sensitivityMatrixFromErrorMatrix(Vector errorVector) {
95-
Matrix derivativeMatrix = createDerivativeMatrix(lastInducedField);
96-
Matrix sensitivityMatrix = derivativeMatrix.times(errorVector).times(
97-
-2.0);
98-
lastSensitivityMatrix = sensitivityMatrix.copy();
99-
return sensitivityMatrix;
100-
}
101-
102-
public Matrix sensitivityMatrixFromSucceedingLayer(Layer nextLayer) {
103-
Matrix derivativeMatrix = createDerivativeMatrix(lastInducedField);
104-
Matrix weightTranspose = nextLayer.weightMatrix.transpose();
105-
Matrix sensitivityMatrix = derivativeMatrix.times(weightTranspose)
106-
.times(nextLayer.getSensitivityMatrix());
107-
lastSensitivityMatrix = sensitivityMatrix.copy();
108-
return sensitivityMatrix;
109-
}
110-
111-
private Matrix getSensitivityMatrix() {
91+
public Matrix getSensitivityMatrix() {
11292

113-
return lastSensitivityMatrix;
93+
return mySensitivityMatrix;
11494
}
11595

11696
public int numberOfNeurons() {
@@ -151,88 +131,38 @@ private static void initializeVector(Vector aVector, double lowerLimit,
151131
}
152132
}
153133

154-
private Matrix createDerivativeMatrix(Vector lastInducedField) {
155-
List<Double> lst = new ArrayList<Double>();
156-
for (int i = 0; i < lastInducedField.size(); i++) {
157-
lst.add(new Double(activationFunction.deriv(lastInducedField
158-
.getValue(i))));
159-
}
160-
return Matrix.createDiagonalMatrix(lst);
161-
}
162-
163-
public Matrix calculateWeightUpdates(Vector previousLayerActivationOrInput,
164-
double alpha) {
165-
Matrix activationTranspose = previousLayerActivationOrInput.transpose();
166-
Matrix weightUpdateMatrix = lastSensitivityMatrix.times(
167-
activationTranspose).times(alpha).times(-1.0);
168-
penultimateWeightUpdateMatrix = lastWeightUpdateMatrix.copy();
169-
lastWeightUpdateMatrix = weightUpdateMatrix.copy();
170-
return weightUpdateMatrix;
171-
}
172-
173-
public Matrix calculateWeightUpdates(Vector previousLayerActivationOrInput,
174-
double alpha, double momentum) {
175-
Matrix activationTranspose = previousLayerActivationOrInput.transpose();
176-
Matrix momentumLessUpdate = lastSensitivityMatrix.times(
177-
activationTranspose).times(alpha).times(-1.0);
178-
Matrix updateWithMomentum = lastWeightUpdateMatrix.times(momentum)
179-
.plus(momentumLessUpdate.times(1.0 - momentum));
180-
penultimateWeightUpdateMatrix = lastWeightUpdateMatrix.copy(); // done
181-
// only
182-
// to
183-
// implement
184-
// VLBP
185-
// later
186-
lastWeightUpdateMatrix = updateWithMomentum.copy();
187-
return updateWithMomentum;
188-
}
189-
190134
/**
 * Returns the most recently stored weight-update matrix. The field itself is
 * returned, not a copy.
 */
public Matrix getLastWeightUpdateMatrix() {
191135
return lastWeightUpdateMatrix;
192136
}
193137

138+
/**
 * Stores {@code m} as the most recent weight-update matrix. The reference is
 * kept as given; no copy is made here.
 */
public void setLastWeightUpdateMatrix(Matrix m) {
139+
lastWeightUpdateMatrix = m;
140+
}
141+
194142
/**
 * Returns the stored penultimate weight-update matrix. The field itself is
 * returned, not a copy.
 */
public Matrix getPenultimateWeightUpdateMatrix() {
195143
return penultimateWeightUpdateMatrix;
196144
}
197145

198-
public Vector calculateBiasUpdates(double alpha) {
199-
Matrix biasUpdateMatrix = lastSensitivityMatrix.times(alpha)
200-
.times(-1.0);
201-
202-
Vector result = new Vector(biasUpdateMatrix.getRowDimension());
203-
for (int i = 0; i < biasUpdateMatrix.getRowDimension(); i++) {
204-
result.setValue(i, biasUpdateMatrix.get(i, 0));
205-
}
206-
penultimateBiasUpdateVector = lastBiasUpdateVector.copyVector();
207-
lastBiasUpdateVector = result.copyVector();
208-
return result;
209-
}
210-
211-
public Vector calculateBiasUpdates(double alpha, double momentum) {
212-
Matrix biasUpdateMatrixWithoutMomentum = lastSensitivityMatrix.times(
213-
alpha).times(-1.0);
214-
;
215-
Matrix biasUpdateMatrixWithMomentum = lastBiasUpdateVector.times(
216-
momentum).plus(
217-
biasUpdateMatrixWithoutMomentum.times(1.0 - momentum));
218-
Vector result = new Vector(biasUpdateMatrixWithMomentum
219-
.getRowDimension());
220-
for (int i = 0; i < biasUpdateMatrixWithMomentum.getRowDimension(); i++) {
221-
result.setValue(i, biasUpdateMatrixWithMomentum.get(i, 0));
222-
}
223-
penultimateBiasUpdateVector = lastBiasUpdateVector.copyVector();
224-
lastBiasUpdateVector = result.copyVector();
225-
return result;
146+
/**
 * Stores {@code m} as the penultimate weight-update matrix. The reference is
 * kept as given; no copy is made here.
 */
public void setPenultimateWeightUpdateMatrix(Matrix m) {
147+
penultimateWeightUpdateMatrix = m;
226148
}
227149

228150
/**
 * Returns the most recently stored bias-update vector. The field itself is
 * returned, not a copy.
 */
public Vector getLastBiasUpdateVector() {
229151
return lastBiasUpdateVector;
230152
}
231153

154+
/**
 * Stores {@code v} as the most recent bias-update vector. The reference is
 * kept as given; no copy is made here.
 */
public void setLastBiasUpdateVector(Vector v) {
155+
lastBiasUpdateVector = v;
156+
}
157+
232158
/**
 * Returns the stored penultimate bias-update vector. The field itself is
 * returned, not a copy.
 */
public Vector getPenultimateBiasUpdateVector() {
233159
return penultimateBiasUpdateVector;
234160
}
235161

162+
/**
 * Stores {@code v} as the penultimate bias-update vector. The reference is
 * kept as given; no copy is made here.
 */
public void setPenultimateBiasUpdateVector(Vector v) {
163+
penultimateBiasUpdateVector = v;
164+
}
165+
236166
/**
 * Applies the stored last weight update to this layer's weight matrix by
 * calling {@code weightMatrix.plusEquals(lastWeightUpdateMatrix)}.
 */
public void updateWeights() {
237167
// presumably plusEquals adds in place -- TODO confirm against Matrix
weightMatrix.plusEquals(lastWeightUpdateMatrix);
238168
}
@@ -251,4 +181,13 @@ public Vector getLastInputValues() {
251181
return lastInput;
252182

253183
}
184+
185+
/**
 * Returns the activation function held by this layer.
 */
public ActivationFunction getActivationFunction() {
186+
187+
return activationFunction;
188+
}
189+
190+
/**
 * Replaces this layer's sensitivity matrix with {@code sensitivityMatrix}.
 * The reference is kept as given; no copy is made here.
 */
public void setSensitivityMatrix(Matrix sensitivityMatrix) {
191+
this.mySensitivityMatrix = sensitivityMatrix;
192+
}
254193
}

src/aima/learning/neural/LayerSensitivity.java

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package aima.learning.neural;
22

3+
import java.util.ArrayList;
4+
import java.util.List;
5+
36
import aima.util.Matrix;
47

58
public class LayerSensitivity {
@@ -8,17 +11,51 @@ public class LayerSensitivity {
811
* Used for backprop learning
912
*/
1013

11-
private final Matrix sensitivityMatrix;
14+
private Matrix sensitivityMatrix;
15+
private final Layer layer;
1216

1317
public LayerSensitivity(Layer layer) {
1418
Matrix weightMatrix = layer.getWeightMatrix();
1519
this.sensitivityMatrix = new Matrix(weightMatrix.getRowDimension(),
1620
weightMatrix.getColumnDimension());
21+
this.layer = layer;
1722

1823
}
1924

2025
/**
 * Returns the currently stored sensitivity matrix. The field itself is
 * returned, not a copy.
 */
public Matrix getSensitivityMatrix() {
2126
return sensitivityMatrix;
2227
}
2328

29+
public Matrix sensitivityMatrixFromErrorMatrix(Vector errorVector) {
30+
Matrix derivativeMatrix = createDerivativeMatrix(layer
31+
.getLastInducedField());
32+
Matrix calculatedSensitivityMatrix = derivativeMatrix
33+
.times(errorVector).times(-2.0);
34+
sensitivityMatrix = calculatedSensitivityMatrix.copy();
35+
return calculatedSensitivityMatrix;
36+
}
37+
38+
public Matrix sensitivityMatrixFromSucceedingLayer(Layer nextLayer) {
39+
Matrix derivativeMatrix = createDerivativeMatrix(layer
40+
.getLastInducedField());
41+
Matrix weightTranspose = nextLayer.getWeightMatrix().transpose();
42+
Matrix calculatedSensitivityMatrix = derivativeMatrix.times(
43+
weightTranspose).times(nextLayer.getSensitivityMatrix());
44+
sensitivityMatrix = calculatedSensitivityMatrix.copy();
45+
return sensitivityMatrix;
46+
}
47+
48+
private Matrix createDerivativeMatrix(Vector lastInducedField) {
49+
List<Double> lst = new ArrayList<Double>();
50+
for (int i = 0; i < lastInducedField.size(); i++) {
51+
lst.add(new Double(layer.getActivationFunction().deriv(
52+
lastInducedField.getValue(i))));
53+
}
54+
return Matrix.createDiagonalMatrix(lst);
55+
}
56+
57+
/**
 * Returns the layer this sensitivity object was constructed for.
 */
public Layer getLayer() {
58+
return layer;
59+
}
60+
2461
}

src/aima/test/learningtest/neural/AllNeuralTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ public static Test suite() {
88
TestSuite suite = new TestSuite("All tests for NN Implementation");
99

1010
suite.addTest(new TestSuite(LayerTests.class));
11-
// suite.addTest(new TestSuite(UtilTests.class));
1211
suite.addTest(new TestSuite(DataSetTests.class));
1312
suite.addTest(new TestSuite(FeedForwardNeuralNetworkTests.class));
1413
return suite;

0 commit comments

Comments
 (0)