
Commit 599574b

Author: magicindian
Message: moved momentum and learning rate into BackProp Training scheme
Parent: 14a46db

6 files changed: +39 −34 lines changed

src/aima/learning/neural/BackPropLearning.java

Lines changed: 15 additions & 12 deletions

@@ -6,23 +6,26 @@ public class BackPropLearning implements NNTrainingScheme {
     private final double learningRate;
     private final double momentum;
 
-    private final Layer hiddenLayer;
-    private final Layer outputLayer;
-    private final LayerSensitivity hiddenSensitivity;
-    private final LayerSensitivity outputSensitivity;
-
-    public BackPropLearning(FeedForwardNeuralNetwork network,
-            double learningRate, double momentum) {
-
-        this.hiddenLayer = network.getHiddenLayer();
-        this.outputLayer = network.getOutputLayer();
-        hiddenSensitivity = new LayerSensitivity(hiddenLayer);
-        outputSensitivity = new LayerSensitivity(outputLayer);
+    private Layer hiddenLayer;
+    private Layer outputLayer;
+    private LayerSensitivity hiddenSensitivity;
+    private LayerSensitivity outputSensitivity;
+
+    public BackPropLearning(double learningRate, double momentum) {
+
         this.learningRate = learningRate;
         this.momentum = momentum;
 
     }
 
+    public void setNeuralNetwork(FeedForwardNeuralNetwork ffnn) {
+
+        this.hiddenLayer = ffnn.getHiddenLayer();
+        this.outputLayer = ffnn.getOutputLayer();
+        this.hiddenSensitivity = new LayerSensitivity(hiddenLayer);
+        this.outputSensitivity = new LayerSensitivity(outputLayer);
+    }
+
     public Vector processInput(FeedForwardNeuralNetwork network, Vector input) {
 
         hiddenLayer.feedForward(input);
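
The scheme is now built from its hyperparameters alone and bound to a network afterwards. A minimal wiring sketch under those assumptions (the 0.1 and 0.9 values are illustrative, and ffnn stands for an already-constructed FeedForwardNeuralNetwork):

    // construct the scheme from hyperparameters only
    BackPropLearning scheme = new BackPropLearning(0.1, 0.9); // learning rate, momentum
    // bind it to a network afterwards; in normal use this call is made
    // for you by FeedForwardNeuralNetwork.setTrainingScheme(scheme)
    scheme.setNeuralNetwork(ffnn);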

src/aima/learning/neural/FeedForwardNeuralNetwork.java

Lines changed: 10 additions & 14 deletions

@@ -8,20 +8,13 @@ public class FeedForwardNeuralNetwork implements FunctionApproximator {
     private final Layer hiddenLayer;
     private final Layer outputLayer;
 
-    private final double learningRate, momentum;
-
-    private final NNTrainingScheme trainingScheme;
+    private NNTrainingScheme trainingScheme;
 
     /*
-     * constructor to be used for non testing code for now assume that config
-     * contains learning rate, momentum parameter, and number of epochs. change
-     * this later to accomodate varied learning schemes like early stopping
+     * constructor to be used for non testing code.
      */
     public FeedForwardNeuralNetwork(NNConfig config) {
 
-        learningRate = config.getParameterAsDouble("learning_rate");
-        momentum = config.getParameterAsDouble("momentum");
-
         int numberOfInputNeurons = config
                 .getParameterAsInteger("number_of_inputs");
         int numberOfHiddenNeurons = config
@@ -41,7 +34,6 @@ public FeedForwardNeuralNetwork(NNConfig config) {
         outputLayer = new Layer(numberOfOutputNeurons, numberOfHiddenNeurons,
                 lowerLimitForWeights, upperLimitForWeights,
                 new PureLinearActivationFunction());
-        trainingScheme = new BackPropLearning(this, learningRate, momentum);
 
     }
 
@@ -52,14 +44,13 @@ public FeedForwardNeuralNetwork(NNConfig config) {
      */
     public FeedForwardNeuralNetwork(Matrix hiddenLayerWeights,
             Vector hiddenLayerBias, Matrix outputLayerWeights,
-            Vector outputLayerBias, double learningRate, double momentum) {
-        this.learningRate = learningRate;
-        this.momentum = momentum;
+            Vector outputLayerBias) {
+
         hiddenLayer = new Layer(hiddenLayerWeights, hiddenLayerBias,
                 new LogSigActivationFunction());
         outputLayer = new Layer(outputLayerWeights, outputLayerBias,
                 new PureLinearActivationFunction());
-        trainingScheme = new BackPropLearning(this, learningRate, momentum);
+
     }
 
     public void processError(Vector error) {
@@ -104,4 +95,9 @@ public Layer getOutputLayer() {
         return outputLayer;
     }
 
+    public void setTrainingScheme(NNTrainingScheme trainingScheme) {
+        this.trainingScheme = trainingScheme;
+        trainingScheme.setNeuralNetwork(this);
+    }
+
 }
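
The network no longer instantiates BackPropLearning itself; a scheme is injected after construction, and setTrainingScheme hands the network back to the scheme so it can cache the layers. A sketch of the new call sequence, mirroring the updated tests below (the weight matrices, bias vectors, and rate variables are assumed to be set up as before):

    FeedForwardNeuralNetwork ffnn = new FeedForwardNeuralNetwork(
            hiddenLayerWeightMatrix, hiddenLayerBiasVector,
            outputLayerWeightMatrix, outputLayerBiasVector);
    // stores the scheme and calls scheme.setNeuralNetwork(this)
    ffnn.setTrainingScheme(new BackPropLearning(learningRate, momentumFactor));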

src/aima/learning/neural/Layer.java

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 import aima.util.Util;
 
 public class Layer {
-    // vectors are represented by n* 1 matrices;
+    // vectors are represented by n * 1 matrices;
     private final Matrix weightMatrix;
 
     Vector biasVector, lastBiasUpdateVector;

src/aima/learning/neural/NNTrainingScheme.java

Lines changed: 2 additions & 0 deletions

@@ -4,4 +4,6 @@ public interface NNTrainingScheme {
     Vector processInput(FeedForwardNeuralNetwork network, Vector input);
 
     void processError(FeedForwardNeuralNetwork network, Vector error);
+
+    void setNeuralNetwork(FeedForwardNeuralNetwork ffnn);
 }
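
Adding setNeuralNetwork to the interface is what lets a scheme exist before its network does; every implementation caches the layers once the network is attached. A hypothetical sketch of an alternative implementation (the class name is invented, and the getLastActivationValues() calls assume the Layer API that BackPropLearning relies on; only feedForward(...) appears in this commit):

    // Hypothetical example: evaluates the network but never updates weights.
    public class EvaluationOnlyScheme implements NNTrainingScheme {

        private Layer hiddenLayer;
        private Layer outputLayer;

        // the network is attached after construction, as in BackPropLearning
        public void setNeuralNetwork(FeedForwardNeuralNetwork ffnn) {
            this.hiddenLayer = ffnn.getHiddenLayer();
            this.outputLayer = ffnn.getOutputLayer();
        }

        public Vector processInput(FeedForwardNeuralNetwork network, Vector input) {
            hiddenLayer.feedForward(input);
            // getLastActivationValues() is assumed; it is not shown in this diff
            outputLayer.feedForward(hiddenLayer.getLastActivationValues());
            return outputLayer.getLastActivationValues();
        }

        public void processError(FeedForwardNeuralNetwork network, Vector error) {
            // intentionally a no-op: no weight or bias updates
        }
    }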

src/aima/test/learningtest/neural/AllNeuralTests.java

Lines changed: 2 additions & 1 deletion

@@ -7,9 +7,10 @@ public class AllNeuralTests {
     public static Test suite() {
         TestSuite suite = new TestSuite("All tests for NN Implementation");
 
+        suite.addTest(new TestSuite(BackPropagationTests.class));
         suite.addTest(new TestSuite(LayerTests.class));
         suite.addTest(new TestSuite(DataSetTests.class));
-        suite.addTest(new TestSuite(FeedForwardNeuralNetworkTests.class));
+
         return suite;
     }
 }

src/aima/test/learningtest/neural/FeedForwardNeuralNetworkTests.java renamed to src/aima/test/learningtest/neural/BackPropagationTests.java

Lines changed: 9 additions & 6 deletions

@@ -1,11 +1,12 @@
 package aima.test.learningtest.neural;
 
 import junit.framework.TestCase;
+import aima.learning.neural.BackPropLearning;
 import aima.learning.neural.FeedForwardNeuralNetwork;
 import aima.learning.neural.Vector;
 import aima.util.Matrix;
 
-public class FeedForwardNeuralNetworkTests extends TestCase {
+public class BackPropagationTests extends TestCase {
 
     public void testFeedForwardAndBAckLoopWorks() {
         // example 11.14 of Neural Network Design by Hagan, Demuth and Beale
@@ -34,8 +35,9 @@ public void testFeedForwardAndBAckLoopWorks() {
         double momentumFactor = 0.0;
         FeedForwardNeuralNetwork ffnn = new FeedForwardNeuralNetwork(
                 hiddenLayerWeightMatrix, hiddenLayerBiasVector,
-                outputLayerWeightMatrix, outputLayerBiasVector, learningRate,
-                momentumFactor);
+                outputLayerWeightMatrix, outputLayerBiasVector);
+        ffnn.setTrainingScheme(new BackPropLearning(learningRate,
+                momentumFactor));
         ffnn.processInput(input);
         ffnn.processError(error);
 
@@ -83,8 +85,10 @@ public void testFeedForwardAndBAckLoopWorksWithMomentum() {
         double momentumFactor = 0.5;
         FeedForwardNeuralNetwork ffnn = new FeedForwardNeuralNetwork(
                 hiddenLayerWeightMatrix, hiddenLayerBiasVector,
-                outputLayerWeightMatrix, outputLayerBiasVector, learningRate,
-                momentumFactor);
+                outputLayerWeightMatrix, outputLayerBiasVector);
+
+        ffnn.setTrainingScheme(new BackPropLearning(learningRate,
+                momentumFactor));
         ffnn.processInput(input);
         ffnn.processError(error);
 
@@ -104,5 +108,4 @@ public void testFeedForwardAndBAckLoopWorksWithMomentum() {
         assertEquals(0.6061, outputLayerBias.getValue(0), 0.001);
 
     }
-
 }
