Skip to content

Commit 14a46db

Browse files
author
magicindian
committed
Sensitivity matrix removed completely from class Layer; it is now held only by LayerSensitivity.
1 parent c80e3c2 commit 14a46db

File tree

4 files changed

+48
-79
lines changed

4 files changed

+48
-79
lines changed

src/aima/learning/neural/BackPropLearning.java

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,10 @@ public Vector processInput(FeedForwardNeuralNetwork network, Vector input) {
3333
public void processError(FeedForwardNeuralNetwork network, Vector error) {
3434
// TODO calculate total error somewhere
3535
// create Sensitivity Matrices
36-
outputLayer.setSensitivityMatrix(outputSensitivity
37-
.sensitivityMatrixFromErrorMatrix(error));
36+
outputSensitivity.sensitivityMatrixFromErrorMatrix(error);
3837

39-
hiddenLayer.setSensitivityMatrix(hiddenSensitivity
40-
.sensitivityMatrixFromSucceedingLayer(outputLayer));
38+
hiddenSensitivity
39+
.sensitivityMatrixFromSucceedingLayer(outputSensitivity);
4140

4241
// calculate weight Updates
4342
calculateWeightUpdates(outputSensitivity, hiddenLayer
@@ -62,18 +61,11 @@ public Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity,
6261
Vector previousLayerActivationOrInput, double alpha, double momentum) {
6362
Layer layer = layerSensitivity.getLayer();
6463
Matrix activationTranspose = previousLayerActivationOrInput.transpose();
65-
Matrix momentumLessUpdate = layer.getSensitivityMatrix().times(
66-
activationTranspose).times(alpha).times(-1.0);
64+
Matrix momentumLessUpdate = layerSensitivity.getSensitivityMatrix()
65+
.times(activationTranspose).times(alpha).times(-1.0);
6766
Matrix updateWithMomentum = layer.getLastWeightUpdateMatrix().times(
6867
momentum).plus(momentumLessUpdate.times(1.0 - momentum));
69-
layer.setPenultimateWeightUpdateMatrix(layer
70-
.getLastWeightUpdateMatrix().copy()); // done
71-
// only
72-
// to
73-
// implement
74-
// VLBP
75-
// later
76-
layer.setLastWeightUpdateMatrix(updateWithMomentum.copy());
68+
layer.acceptNewWeightUpdate(updateWithMomentum.copy());
7769
return updateWithMomentum;
7870
}
7971

@@ -84,17 +76,15 @@ public static Matrix calculateWeightUpdates(
8476
Matrix activationTranspose = previousLayerActivationOrInput.transpose();
8577
Matrix weightUpdateMatrix = layerSensitivity.getSensitivityMatrix()
8678
.times(activationTranspose).times(alpha).times(-1.0);
87-
layer.setPenultimateWeightUpdateMatrix(layer
88-
.getLastWeightUpdateMatrix().copy());
89-
layer.setLastWeightUpdateMatrix(weightUpdateMatrix.copy());
79+
layer.acceptNewWeightUpdate(weightUpdateMatrix.copy());
9080
return weightUpdateMatrix;
9181
}
9282

9383
public Vector calculateBiasUpdates(LayerSensitivity layerSensitivity,
9484
double alpha, double momentum) {
9585
Layer layer = layerSensitivity.getLayer();
96-
Matrix biasUpdateMatrixWithoutMomentum = layer.getSensitivityMatrix()
97-
.times(alpha).times(-1.0);
86+
Matrix biasUpdateMatrixWithoutMomentum = layerSensitivity
87+
.getSensitivityMatrix().times(alpha).times(-1.0);
9888

9989
Matrix biasUpdateMatrixWithMomentum = layer.getLastBiasUpdateVector()
10090
.times(momentum).plus(
@@ -104,25 +94,21 @@ public Vector calculateBiasUpdates(LayerSensitivity layerSensitivity,
10494
for (int i = 0; i < biasUpdateMatrixWithMomentum.getRowDimension(); i++) {
10595
result.setValue(i, biasUpdateMatrixWithMomentum.get(i, 0));
10696
}
107-
layer.setPenultimateBiasUpdateVector(layer.getLastBiasUpdateVector()
108-
.copyVector());
109-
layer.setLastBiasUpdateVector(result.copyVector());
97+
layer.acceptNewBiasUpdate(result.copyVector());
11098
return result;
11199
}
112100

113101
public static Vector calculateBiasUpdates(
114102
LayerSensitivity layerSensitivity, double alpha) {
115103
Layer layer = layerSensitivity.getLayer();
116-
Matrix biasUpdateMatrix = layer.getSensitivityMatrix().times(alpha)
117-
.times(-1.0);
104+
Matrix biasUpdateMatrix = layerSensitivity.getSensitivityMatrix()
105+
.times(alpha).times(-1.0);
118106

119107
Vector result = new Vector(biasUpdateMatrix.getRowDimension());
120108
for (int i = 0; i < biasUpdateMatrix.getRowDimension(); i++) {
121109
result.setValue(i, biasUpdateMatrix.get(i, 0));
122110
}
123-
layer.setPenultimateBiasUpdateVector(layer.getLastBiasUpdateVector()
124-
.copyVector());
125-
layer.setLastBiasUpdateVector(result.copyVector());
111+
layer.acceptNewBiasUpdate(result.copyVector());
126112
return result;
127113
}
128114
}

src/aima/learning/neural/Layer.java

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ public class Layer {
1414

1515
private Vector lastActivationValues, lastInducedField;
1616

17-
private Matrix mySensitivityMatrix;
18-
1917
private Matrix lastWeightUpdateMatrix;
2018

2119
private Matrix penultimateWeightUpdateMatrix;
@@ -32,8 +30,7 @@ public Layer(Matrix weightMatrix, Vector biasVector, ActivationFunction af) {
3230
weightMatrix.getColumnDimension());
3331
penultimateWeightUpdateMatrix = new Matrix(weightMatrix
3432
.getRowDimension(), weightMatrix.getColumnDimension());
35-
mySensitivityMatrix = new Matrix(weightMatrix.getRowDimension(),
36-
weightMatrix.getColumnDimension());
33+
3734
this.biasVector = biasVector;
3835
lastBiasUpdateVector = new Vector(biasVector.getRowDimension());
3936
penultimateBiasUpdateVector = new Vector(biasVector.getRowDimension());
@@ -42,7 +39,7 @@ public Layer(Matrix weightMatrix, Vector biasVector, ActivationFunction af) {
4239
public Layer(int numberOfNeurons, int numberOfInputs,
4340
double lowerLimitForWeights, double upperLimitForWeights,
4441
ActivationFunction af) {
45-
// sensitivityMatrix = new Matrix(numberOfNeurons, numberOfInputs);
42+
4643
activationFunction = af;
4744
this.weightMatrix = new Matrix(numberOfNeurons, numberOfInputs());
4845
lastWeightUpdateMatrix = new Matrix(weightMatrix.getRowDimension(),
@@ -88,11 +85,6 @@ public Vector getBiasVector() {
8885
return biasVector;
8986
}
9087

91-
public Matrix getSensitivityMatrix() {
92-
93-
return mySensitivityMatrix;
94-
}
95-
9688
public int numberOfNeurons() {
9789
return weightMatrix.getRowDimension();
9890
}
@@ -187,7 +179,17 @@ public ActivationFunction getActivationFunction() {
187179
return activationFunction;
188180
}
189181

190-
public void setSensitivityMatrix(Matrix sensitivityMatrix) {
191-
this.mySensitivityMatrix = sensitivityMatrix;
182+
public void acceptNewWeightUpdate(Matrix weightUpdate) {
183+
/*
184+
* penultimate weight updates are maintained only to implement VLBP later
185+
*/
186+
setPenultimateWeightUpdateMatrix(getLastWeightUpdateMatrix());
187+
setLastWeightUpdateMatrix(weightUpdate);
192188
}
189+
190+
public void acceptNewBiasUpdate(Vector biasUpdate) {
191+
setPenultimateBiasUpdateVector(getLastBiasUpdateVector());
192+
setLastBiasUpdateVector(biasUpdate);
193+
}
194+
193195
}

src/aima/learning/neural/LayerSensitivity.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,15 @@ public Matrix sensitivityMatrixFromErrorMatrix(Vector errorVector) {
3535
return calculatedSensitivityMatrix;
3636
}
3737

38-
public Matrix sensitivityMatrixFromSucceedingLayer(Layer nextLayer) {
38+
public Matrix sensitivityMatrixFromSucceedingLayer(
39+
LayerSensitivity nextLayerSensitivity) {
40+
Layer nextLayer = nextLayerSensitivity.getLayer();
3941
Matrix derivativeMatrix = createDerivativeMatrix(layer
4042
.getLastInducedField());
4143
Matrix weightTranspose = nextLayer.getWeightMatrix().transpose();
4244
Matrix calculatedSensitivityMatrix = derivativeMatrix.times(
43-
weightTranspose).times(nextLayer.getSensitivityMatrix());
45+
weightTranspose).times(
46+
nextLayerSensitivity.getSensitivityMatrix());
4447
sensitivityMatrix = calculatedSensitivityMatrix.copy();
4548
return sensitivityMatrix;
4649
}

src/aima/test/learningtest/neural/LayerTests.java

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,10 @@ public void testSensitivityMatrixCalculationFromErrorVector() {
8282
Vector errorVector = new Vector(1);
8383
errorVector.setValue(0, 1.261);
8484
LayerSensitivity layer2Sensitivity = new LayerSensitivity(layer2);
85-
layer2.setSensitivityMatrix(layer2Sensitivity
86-
.sensitivityMatrixFromErrorMatrix(errorVector));
85+
layer2Sensitivity.sensitivityMatrixFromErrorMatrix(errorVector);
8786

88-
Matrix sensitivityMatrix = layer2.getSensitivityMatrix();
87+
Matrix sensitivityMatrix = layer2Sensitivity.getSensitivityMatrix();
8988
assertEquals(-2.522, sensitivityMatrix.get(0, 0));
90-
// System.out.println(sensistivityMatrix);
9189

9290
}
9391

@@ -124,14 +122,11 @@ public void testSensitivityMatrixCalculationFromSucceedingLayer() {
124122
Vector errorVector = new Vector(1);
125123
errorVector.setValue(0, 1.261);
126124
LayerSensitivity layer2Sensitivity = new LayerSensitivity(layer2);
127-
layer2.setSensitivityMatrix(layer2Sensitivity
128-
.sensitivityMatrixFromErrorMatrix(errorVector));
125+
layer2Sensitivity.sensitivityMatrixFromErrorMatrix(errorVector);
129126

130-
// Matrix sensitivityMatrix = layer1
131-
// .sensitivityMatrixFromSucceedingLayer(layer2);
132-
layer1.setSensitivityMatrix(layer1Sensitivity
133-
.sensitivityMatrixFromSucceedingLayer(layer2));
134-
Matrix sensitivityMatrix = layer1.getSensitivityMatrix();
127+
layer1Sensitivity
128+
.sensitivityMatrixFromSucceedingLayer(layer2Sensitivity);
129+
Matrix sensitivityMatrix = layer1Sensitivity.getSensitivityMatrix();
135130

136131
assertEquals(2, sensitivityMatrix.getRowDimension());
137132
assertEquals(1, sensitivityMatrix.getColumnDimension());
@@ -173,15 +168,11 @@ public void testWeightUpdateMatrixesFormedCorrectly() {
173168
Vector errorVector = new Vector(1);
174169
errorVector.setValue(0, 1.261);
175170
LayerSensitivity layer2Sensitivity = new LayerSensitivity(layer2);
176-
layer2.setSensitivityMatrix(layer2Sensitivity
177-
.sensitivityMatrixFromErrorMatrix(errorVector));
171+
layer2Sensitivity.sensitivityMatrixFromErrorMatrix(errorVector);
178172

179-
// layer1.sensitivityMatrixFromSucceedingLayer(layer2);
180-
layer1.setSensitivityMatrix(layer1Sensitivity
181-
.sensitivityMatrixFromSucceedingLayer(layer2));
173+
layer1Sensitivity
174+
.sensitivityMatrixFromSucceedingLayer(layer2Sensitivity);
182175

183-
// Matrix weightUpdateMatrix2 = layer2.calculateWeightUpdates(layer1
184-
// .getLastActivationValues(), 0.1);
185176
Matrix weightUpdateMatrix2 = BackPropLearning.calculateWeightUpdates(
186177
layer2Sensitivity, layer1.getLastActivationValues(), 0.1);
187178
assertEquals(0.0809, weightUpdateMatrix2.get(0, 0), 0.001);
@@ -196,8 +187,6 @@ public void testWeightUpdateMatrixesFormedCorrectly() {
196187
assertEquals(0.0, penultimateWeightUpdatematrix2.get(0, 0), 0.001);
197188
assertEquals(0.0, penultimateWeightUpdatematrix2.get(0, 1), 0.001);
198189

199-
// Matrix weightUpdateMatrix1 = layer1.calculateWeightUpdates(
200-
// inputVector1, 0.1);
201190
Matrix weightUpdateMatrix1 = BackPropLearning.calculateWeightUpdates(
202191
layer1Sensitivity, inputVector1, 0.1);
203192
assertEquals(0.0049, weightUpdateMatrix1.get(0, 0), 0.001);
@@ -210,7 +199,6 @@ public void testWeightUpdateMatrixesFormedCorrectly() {
210199
.getPenultimateWeightUpdateMatrix();
211200
assertEquals(0.0, penultimateWeightUpdatematrix1.get(0, 0), 0.001);
212201
assertEquals(0.0, penultimateWeightUpdatematrix1.get(1, 0), 0.001);
213-
// System.out.println(weightUpdateMatrix1);
214202

215203
}
216204

@@ -247,13 +235,11 @@ public void testBiasUpdateMatrixesFormedCorrectly() {
247235

248236
Vector errorVector = new Vector(1);
249237
errorVector.setValue(0, 1.261);
250-
layer2.setSensitivityMatrix(layer2Sensitivity
251-
.sensitivityMatrixFromErrorMatrix(errorVector));
252-
// layer1.sensitivityMatrixFromSucceedingLayer(layer2);
253-
layer1.setSensitivityMatrix(layer1Sensitivity
254-
.sensitivityMatrixFromSucceedingLayer(layer2));
238+
layer2Sensitivity.sensitivityMatrixFromErrorMatrix(errorVector);
239+
240+
layer1Sensitivity
241+
.sensitivityMatrixFromSucceedingLayer(layer2Sensitivity);
255242

256-
// Vector biasUpdateVector2 = layer2.calculateBiasUpdates(0.1);
257243
Vector biasUpdateVector2 = BackPropLearning.calculateBiasUpdates(
258244
layer2Sensitivity, 0.1);
259245
assertEquals(0.2522, biasUpdateVector2.getValue(0), 0.001);
@@ -265,7 +251,6 @@ public void testBiasUpdateMatrixesFormedCorrectly() {
265251
.getPenultimateBiasUpdateVector();
266252
assertEquals(0.0, penultimateBiasUpdateVector2.getValue(0), 0.001);
267253

268-
// Vector biasUpdateVector1 = layer1.calculateBiasUpdates(0.1);
269254
Vector biasUpdateVector1 = BackPropLearning.calculateBiasUpdates(
270255
layer1Sensitivity, 0.1);
271256
assertEquals(0.00495, biasUpdateVector1.getValue(0), 0.001);
@@ -316,26 +301,19 @@ public void testWeightsAndBiasesUpdatedCorrectly() {
316301
Vector errorVector = new Vector(1);
317302
errorVector.setValue(0, 1.261);
318303
LayerSensitivity layer2Sensitivity = new LayerSensitivity(layer2);
319-
layer2.setSensitivityMatrix(layer2Sensitivity
320-
.sensitivityMatrixFromErrorMatrix(errorVector));
321-
// layer1.sensitivityMatrixFromSucceedingLayer(layer2);
322-
layer1.setSensitivityMatrix(layer1Sensitivity
323-
.sensitivityMatrixFromSucceedingLayer(layer2));
304+
layer2Sensitivity.sensitivityMatrixFromErrorMatrix(errorVector);
324305

325-
// layer2.calculateWeightUpdates(layer1.getLastActivationValues(), 0.1);
306+
layer1Sensitivity
307+
.sensitivityMatrixFromSucceedingLayer(layer2Sensitivity);
326308

327309
BackPropLearning.calculateWeightUpdates(layer2Sensitivity, layer1
328310
.getLastActivationValues(), 0.1);
329311

330-
// layer2.calculateBiasUpdates(0.1);
331312
BackPropLearning.calculateBiasUpdates(layer2Sensitivity, 0.1);
332313

333-
// layer1.calculateWeightUpdates(inputVector1, 0.1);
334-
335314
BackPropLearning.calculateWeightUpdates(layer1Sensitivity,
336315
inputVector1, 0.1);
337316

338-
// layer1.calculateBiasUpdates(0.1);
339317
BackPropLearning.calculateBiasUpdates(layer1Sensitivity, 0.1);
340318

341319
layer2.updateWeights();

0 commit comments

Comments
 (0)