Skip to content

Commit c54e424

Browse files
committed
conv nets fixes
1 parent e284b54 commit c54e424

File tree

9 files changed

+111
-47
lines changed

9 files changed

+111
-47
lines changed

nn-core/src/main/java/com/github/neuralnetworks/calculation/neuronfunctions/AparapiAveragePooling2D.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import com.github.neuralnetworks.architecture.Subsampling2DConnection;
88
import com.github.neuralnetworks.calculation.ConnectionCalculator;
99
import com.github.neuralnetworks.calculation.memory.ValuesProvider;
10-
import com.github.neuralnetworks.util.TensorFactory;
1110

1211
/**
1312
* Average pooling
@@ -21,7 +20,7 @@ public class AparapiAveragePooling2D implements ConnectionCalculator {
2120

2221
@Override
2322
public void calculate(List<Connections> connections, ValuesProvider valuesProvider, Layer targetLayer) {
24-
if (cc == null || cc.getMiniBatchSize() != TensorFactory.batchSize(valuesProvider)) {
23+
if (cc == null || !cc.accept((Subsampling2DConnection) connections.get(0), valuesProvider)) {
2524
cc = new AparapiAveragePooling2DCC((Subsampling2DConnection) connections.get(0), valuesProvider, targetLayer);
2625
}
2726

nn-core/src/main/java/com/github/neuralnetworks/calculation/neuronfunctions/AparapiConv2D.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,22 @@ public void run() {
135135
outputStartIndex + (id / outputFeatureMapLength) * outputFeatureMapsDistance + ((id % outputFeatureMapLength) / outputColumns) * outputFeatureMapRowsDistance + (id % outputColumns) * outputFeatureMapColumnsDistance);
136136
}
137137

138+
public boolean accept(Conv2DConnection c, ValuesProvider valuesProvider) {
139+
if (TensorFactory.batchSize(valuesProvider) != miniBatchSize) {
140+
return false;
141+
}
142+
143+
if (TensorFactory.tensor(c.getOutputLayer(), c, valuesProvider).getElements() != output) {
144+
return false;
145+
}
146+
147+
if (TensorFactory.tensor(Util.getOppositeLayer(c, c.getOutputLayer()), c, valuesProvider).getElements() != input) {
148+
return false;
149+
}
150+
151+
return true;
152+
}
153+
138154
/**
139155
* the actual convolution
140156
* @param weightsStartId

nn-core/src/main/java/com/github/neuralnetworks/calculation/neuronfunctions/AparapiMaxPooling2D.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import com.github.neuralnetworks.architecture.Subsampling2DConnection;
88
import com.github.neuralnetworks.calculation.ConnectionCalculator;
99
import com.github.neuralnetworks.calculation.memory.ValuesProvider;
10-
import com.github.neuralnetworks.util.TensorFactory;
1110

1211
/**
1312
* Max pooling
@@ -20,7 +19,7 @@ public class AparapiMaxPooling2D implements ConnectionCalculator {
2019

2120
@Override
2221
public void calculate(List<Connections> connections, ValuesProvider valuesProvider, Layer targetLayer) {
23-
if (cc == null || cc.getMiniBatchSize() != TensorFactory.batchSize(valuesProvider)) {
22+
if (cc == null || !cc.accept((Subsampling2DConnection) connections.get(0), valuesProvider)) {
2423
cc = new AparapiMaxPooling2DCC((Subsampling2DConnection) connections.get(0), valuesProvider, targetLayer);
2524
}
2625

nn-core/src/main/java/com/github/neuralnetworks/calculation/neuronfunctions/AparapiStochasticPooling2D.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import com.github.neuralnetworks.architecture.Subsampling2DConnection;
88
import com.github.neuralnetworks.calculation.ConnectionCalculator;
99
import com.github.neuralnetworks.calculation.memory.ValuesProvider;
10-
import com.github.neuralnetworks.util.TensorFactory;
1110

1211
/**
1312
* Stochastic pooling
@@ -20,7 +19,7 @@ public class AparapiStochasticPooling2D implements ConnectionCalculator {
2019

2120
@Override
2221
public void calculate(List<Connections> connections, ValuesProvider valuesProvider, Layer targetLayer) {
23-
if (cc == null || cc.getMiniBatchSize() != TensorFactory.batchSize(valuesProvider)) {
22+
if (cc == null || !cc.accept((Subsampling2DConnection) connections.get(0), valuesProvider)) {
2423
cc = new AparapiStochasticPooling2DCC((Subsampling2DConnection) connections.get(0), valuesProvider, targetLayer);
2524
}
2625

nn-core/src/main/java/com/github/neuralnetworks/calculation/neuronfunctions/AparapiSubsampling2D.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,22 @@ public void run() {
154154
outputStartIndex + fm * outputFeatureMapsDistance + fmRow * outputFeatureMapRowsDistance + fmCol * outputFeatureMapColumnsDistance);
155155
}
156156

157+
public boolean accept(Subsampling2DConnection c, ValuesProvider valuesProvider) {
158+
if (TensorFactory.batchSize(valuesProvider) != miniBatchSize) {
159+
return false;
160+
}
161+
162+
if (TensorFactory.tensor(c.getOutputLayer(), c, valuesProvider).getElements() != output) {
163+
return false;
164+
}
165+
166+
if (TensorFactory.tensor(Util.getOppositeLayer(c, c.getOutputLayer()), c, valuesProvider).getElements() != input) {
167+
return false;
168+
}
169+
170+
return true;
171+
}
172+
157173
/**
158174
* This is where the subsampling happens
159175
*/

nn-core/src/main/java/com/github/neuralnetworks/calculation/neuronfunctions/ConnectionCalculatorConv.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
import com.github.neuralnetworks.util.Util;
1414

1515
/**
16-
* Default implementation of Connection calculator for convolutional/subsampling
17-
* layers
16+
* Default implementation of Connection calculator for convolutional/subsampling layers
1817
*/
1918
public class ConnectionCalculatorConv implements ConnectionCalculator {
2019

@@ -41,7 +40,7 @@ public void calculate(List<Connections> connections, ValuesProvider valuesProvid
4140

4241
if (c != null) {
4342
// currently works only as a feedforward (including bp)
44-
if (inputFunction == null || miniBatchSize != TensorFactory.batchSize(valuesProvider)) {
43+
if (inputFunction == null || !inputFunction.accept(c, valuesProvider)) {
4544
miniBatchSize = TensorFactory.batchSize(valuesProvider);
4645
inputFunction = createInputFunction(c, valuesProvider, targetLayer);
4746
}

nn-core/src/main/java/com/github/neuralnetworks/training/events/LogTrainingListener.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public void handleEvent(TrainingEvent event) {
6565

6666
StringBuilder sb = new StringBuilder();
6767
sb.append(((finishTime - startTime) / 1000f) + " s total time" + s);
68-
sb.append((miniBatchTotalTime / (miniBatches * 1000f)) + " s per minibatch of " + miniBatches + " mini batches" + s);
68+
sb.append((miniBatchTotalTime / (miniBatches * 1000f)) + " s per minibatch of " + miniBatches + " batches" + s);
6969
if (event instanceof TestingFinishedEvent) {
7070
Trainer<?> t = (Trainer<?>) event.getSource();
7171
OutputError oe = t.getOutputError();
@@ -83,7 +83,7 @@ public void handleEvent(TrainingEvent event) {
8383
String s = System.getProperty("line.separator");
8484

8585
if (miniBatchTime / 5000 > 0 && (logMiniBatches || (isTesting && logTestResults))) {
86-
sb.append(miniBatches + " minibatches in " + (miniBatchTotalTime / 1000f) + " s" + s);
86+
sb.append(miniBatches + " batches in " + (miniBatchTotalTime / 1000f) + " s" + s);
8787
miniBatchTime = 0;
8888
}
8989

nn-core/src/test/java/com/github/neuralnetworks/test/CNNTest.java

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -677,35 +677,4 @@ public void testCNNMLPFF() {
677677

678678
assertTrue(Arrays.equals(cnnvp.get(cnn.getOutputLayer()).getElements(), mlpvp.get(mlp.getOutputLayer()).getElements()));
679679
}
680-
681-
@Test
682-
public void testCNNMLPBP() {
683-
Environment.getInstance().setExecutionMode(EXECUTION_MODE.SEQ);
684-
685-
Environment.getInstance().setUseDataSharedMemory(false);
686-
687-
// CNN
688-
NeuralNetworkImpl cnn = NNFactory.convNN(new int[][] { { 2, 1, 1 }, { 1, 1 }, {1} }, false);
689-
cnn.setLayerCalculator(NNFactory.lcSigmoid(cnn, null));
690-
NNFactory.lcMaxPooling(cnn);
691-
FullyConnected cnnfc = (FullyConnected) cnn.getOutputLayer().getConnections().get(0);
692-
cnnfc.getWeights().set(0.05f, 0, 0);
693-
cnnfc.getWeights().set(0.08f, 0, 1);
694-
695-
// MLP
696-
NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[] { 2, 1 }, false);
697-
FullyConnected mlpfc = (FullyConnected) mlp.getOutputLayer().getConnections().get(0);
698-
mlpfc.getWeights().set(0.05f, 0, 0);
699-
mlpfc.getWeights().set(0.08f, 0, 1);
700-
701-
// compare bp
702-
SimpleInputProvider inputProvider = new SimpleInputProvider(new float[][] { { 0.35f, 0.9f }, { 0.8f, 0.2f } }, new float[][] { { 0.5f }, { 0.8f } });
703-
BackPropagationTrainer<?> cnnbpt = TrainerFactory.backPropagation(cnn, inputProvider, null, null, null, 1f, 0f, 0f, 0f, 0f, 1, 1, 20);
704-
cnnbpt.train();
705-
706-
BackPropagationTrainer<?> mlpbpt = TrainerFactory.backPropagation(mlp, inputProvider, null, null, null, 1f, 0f, 0f, 0f, 0f, 1, 1, 20);
707-
mlpbpt.train();
708-
709-
assertTrue(Arrays.equals(cnnfc.getWeights().getElements(), mlpfc.getWeights().getElements()));
710-
}
711680
}

nn-samples/src/test/java/com/github/neuralnetworks/samples/test/XorTest.java

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package com.github.neuralnetworks.samples.test;
22

33
import static org.junit.Assert.assertEquals;
4+
import static org.junit.Assert.assertTrue;
5+
6+
import java.util.Arrays;
47

58
import org.junit.Test;
69

710
import com.amd.aparapi.Kernel.EXECUTION_MODE;
11+
import com.github.neuralnetworks.architecture.FullyConnected;
812
import com.github.neuralnetworks.architecture.NeuralNetworkImpl;
913
import com.github.neuralnetworks.architecture.types.NNFactory;
1014
import com.github.neuralnetworks.samples.xor.XorOutputError;
@@ -26,15 +30,20 @@ public void testMLPSigmoidBP() {
2630
Environment.getInstance().setExecutionMode(EXECUTION_MODE.SEQ);
2731

2832
// create multi layer perceptron with one hidden layer and bias
29-
Environment.getInstance().setUseWeightsSharedMemory(true);
30-
NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[] { 2, 8, 1 }, true);
33+
Environment.getInstance().setUseWeightsSharedMemory(false);
34+
Environment.getInstance().setUseDataSharedMemory(false);
35+
//NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[] { 2, 4, 1 }, false);
36+
NeuralNetworkImpl mlp = NNFactory.convNN(new int[][] { { 2, 1, 1 }, { 1, 1 }, { 4 }, {1} }, false);
37+
//NeuralNetworkImpl mlp = NNFactory.convNN(new int[][] { {2, 1, 1}, {4}, {1} }, false);
38+
mlp.setLayerCalculator(NNFactory.lcSigmoid(mlp, null));
39+
NNFactory.lcMaxPooling(mlp);
40+
3141

3242
// create training and testing input providers
33-
SimpleInputProvider trainingInput = new SimpleInputProvider(new float[][] { {0, 0}, {0, 1}, {1, 0}, {1, 1} }, new float[][] { {0}, {1}, {1}, {0} });
34-
SimpleInputProvider testingInput = new SimpleInputProvider(new float[][] { {0, 0}, {0, 1}, {1, 0}, {1, 1} }, new float[][] { {0}, {1}, {1}, {0} });
43+
SimpleInputProvider input = new SimpleInputProvider(new float[][] { {0, 0}, {0, 1}, {1, 0}, {1, 1} }, new float[][] { {0}, {1}, {1}, {0} });
3544

3645
// create backpropagation trainer for the network
37-
BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, trainingInput, testingInput, new XorOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 1f, 0.5f, 0f, 0f, 0f, 1, 1, 2500);
46+
BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, input, input, new XorOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 1f, 0.5f, 0f, 0f, 0f, 1, 1, 50000);
3847

3948
// add logging
4049
bpt.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName()));
@@ -50,4 +59,62 @@ public void testMLPSigmoidBP() {
5059

5160
assertEquals(0, bpt.getOutputError().getTotalNetworkError(), 0.1);
5261
}
62+
63+
@Test
64+
public void testCNNMLPBP() {
65+
Environment.getInstance().setExecutionMode(EXECUTION_MODE.SEQ);
66+
67+
Environment.getInstance().setUseDataSharedMemory(true);
68+
Environment.getInstance().setUseWeightsSharedMemory(true);
69+
70+
// CNN
71+
NeuralNetworkImpl cnn = NNFactory.convNN(new int[][] { { 2, 1, 1 }, { 1, 1 }, { 4 }, {1} }, false);
72+
cnn.setLayerCalculator(NNFactory.lcSigmoid(cnn, null));
73+
NNFactory.lcMaxPooling(cnn);
74+
FullyConnected cnnfci = (FullyConnected) cnn.getOutputLayer().getConnections().get(0).getInputLayer().getConnections().get(0);
75+
cnnfci.getWeights().set(0.02f, 0, 0);
76+
cnnfci.getWeights().set(0.01f, 1, 0);
77+
cnnfci.getWeights().set(0.03f, 2, 0);
78+
cnnfci.getWeights().set(0.001f, 3, 0);
79+
cnnfci.getWeights().set(0.005f, 0, 1);
80+
cnnfci.getWeights().set(0.04f, 1, 1);
81+
cnnfci.getWeights().set(0.02f, 2, 1);
82+
cnnfci.getWeights().set(0.009f, 3, 1);
83+
84+
FullyConnected cnnfco = (FullyConnected) cnn.getOutputLayer().getConnections().get(0);
85+
cnnfco.getWeights().set(0.05f, 0, 0);
86+
cnnfco.getWeights().set(0.08f, 0, 1);
87+
88+
// MLP
89+
NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[] { 2, 4, 1 }, false);
90+
91+
FullyConnected mlpfci = (FullyConnected) mlp.getOutputLayer().getConnections().get(0).getInputLayer().getConnections().get(0);
92+
mlpfci.getWeights().set(0.02f, 0, 0);
93+
mlpfci.getWeights().set(0.01f, 1, 0);
94+
mlpfci.getWeights().set(0.03f, 2, 0);
95+
mlpfci.getWeights().set(0.001f, 3, 0);
96+
mlpfci.getWeights().set(0.005f, 0, 1);
97+
mlpfci.getWeights().set(0.04f, 1, 1);
98+
mlpfci.getWeights().set(0.02f, 2, 1);
99+
mlpfci.getWeights().set(0.009f, 3, 1);
100+
101+
FullyConnected mlpfco = (FullyConnected) mlp.getOutputLayer().getConnections().get(0);
102+
mlpfco.getWeights().set(0.05f, 0, 0);
103+
mlpfco.getWeights().set(0.08f, 0, 1);
104+
105+
// compare bp
106+
SimpleInputProvider inputProvider = new SimpleInputProvider(new float[][] { {0, 0}, {0, 1}, {1, 0}, {1, 1} }, new float[][] { {0}, {1}, {1}, {0} });
107+
108+
BackPropagationTrainer<?> mlpbpt = TrainerFactory.backPropagation(mlp, inputProvider, inputProvider, new XorOutputError(), null, 1f, 0f, 0f, 0f, 0f, 1, 1, 10000);
109+
mlpbpt.train();
110+
mlpbpt.test();
111+
112+
BackPropagationTrainer<?> cnnbpt = TrainerFactory.backPropagation(cnn, inputProvider, inputProvider, new XorOutputError(), null, 1f, 0f, 0f, 0f, 0f, 1, 1, 10000);
113+
cnnbpt.train();
114+
cnnbpt.test();
115+
116+
assertEquals(mlpbpt.getOutputError().getTotalNetworkError(), cnnbpt.getOutputError().getTotalNetworkError(), 0);
117+
assertTrue(Arrays.equals(cnnfco.getWeights().getElements(), mlpfco.getWeights().getElements()));
118+
assertTrue(Arrays.equals(cnnfci.getWeights().getElements(), mlpfci.getWeights().getElements()));
119+
}
53120
}

0 commit comments

Comments (0)