11package com .github .neuralnetworks .samples .test ;
22
33import static org .junit .Assert .assertEquals ;
4+ import static org .junit .Assert .assertTrue ;
5+
6+ import java .util .Arrays ;
47
58import org .junit .Test ;
69
710import com .amd .aparapi .Kernel .EXECUTION_MODE ;
11+ import com .github .neuralnetworks .architecture .FullyConnected ;
812import com .github .neuralnetworks .architecture .NeuralNetworkImpl ;
913import com .github .neuralnetworks .architecture .types .NNFactory ;
1014import com .github .neuralnetworks .samples .xor .XorOutputError ;
@@ -26,15 +30,20 @@ public void testMLPSigmoidBP() {
2630 Environment .getInstance ().setExecutionMode (EXECUTION_MODE .SEQ );
2731
2832 // create multi layer perceptron with one hidden layer and bias
29- Environment .getInstance ().setUseWeightsSharedMemory (true );
30- NeuralNetworkImpl mlp = NNFactory .mlpSigmoid (new int [] { 2 , 8 , 1 }, true );
33+ Environment .getInstance ().setUseWeightsSharedMemory (false );
34+ Environment .getInstance ().setUseDataSharedMemory (false );
35+ //NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[] { 2, 4, 1 }, false);
36+ NeuralNetworkImpl mlp = NNFactory .convNN (new int [][] { { 2 , 1 , 1 }, { 1 , 1 }, { 4 }, {1 } }, false );
37+ //NeuralNetworkImpl mlp = NNFactory.convNN(new int[][] { {2, 1, 1}, {4}, {1} }, false);
38+ mlp .setLayerCalculator (NNFactory .lcSigmoid (mlp , null ));
39+ NNFactory .lcMaxPooling (mlp );
40+
3141
3242 // create training and testing input providers
33- SimpleInputProvider trainingInput = new SimpleInputProvider (new float [][] { {0 , 0 }, {0 , 1 }, {1 , 0 }, {1 , 1 } }, new float [][] { {0 }, {1 }, {1 }, {0 } });
34- SimpleInputProvider testingInput = new SimpleInputProvider (new float [][] { {0 , 0 }, {0 , 1 }, {1 , 0 }, {1 , 1 } }, new float [][] { {0 }, {1 }, {1 }, {0 } });
43+ SimpleInputProvider input = new SimpleInputProvider (new float [][] { {0 , 0 }, {0 , 1 }, {1 , 0 }, {1 , 1 } }, new float [][] { {0 }, {1 }, {1 }, {0 } });
3544
3645 // create backpropagation trainer for the network
37- BackPropagationTrainer <?> bpt = TrainerFactory .backPropagation (mlp , trainingInput , testingInput , new XorOutputError (), new NNRandomInitializer (new MersenneTwisterRandomInitializer (-0.01f , 0.01f )), 1f , 0.5f , 0f , 0f , 0f , 1 , 1 , 2500 );
46+ BackPropagationTrainer <?> bpt = TrainerFactory .backPropagation (mlp , input , input , new XorOutputError (), new NNRandomInitializer (new MersenneTwisterRandomInitializer (-0.01f , 0.01f )), 1f , 0.5f , 0f , 0f , 0f , 1 , 1 , 50000 );
3847
3948 // add logging
4049 bpt .addEventListener (new LogTrainingListener (Thread .currentThread ().getStackTrace ()[1 ].getMethodName ()));
@@ -50,4 +59,62 @@ public void testMLPSigmoidBP() {
5059
5160 assertEquals (0 , bpt .getOutputError ().getTotalNetworkError (), 0.1 );
5261 }
62+
63+ @ Test
64+ public void testCNNMLPBP () {
65+ Environment .getInstance ().setExecutionMode (EXECUTION_MODE .SEQ );
66+
67+ Environment .getInstance ().setUseDataSharedMemory (true );
68+ Environment .getInstance ().setUseWeightsSharedMemory (true );
69+
70+ // CNN
71+ NeuralNetworkImpl cnn = NNFactory .convNN (new int [][] { { 2 , 1 , 1 }, { 1 , 1 }, { 4 }, {1 } }, false );
72+ cnn .setLayerCalculator (NNFactory .lcSigmoid (cnn , null ));
73+ NNFactory .lcMaxPooling (cnn );
74+ FullyConnected cnnfci = (FullyConnected ) cnn .getOutputLayer ().getConnections ().get (0 ).getInputLayer ().getConnections ().get (0 );
75+ cnnfci .getWeights ().set (0.02f , 0 , 0 );
76+ cnnfci .getWeights ().set (0.01f , 1 , 0 );
77+ cnnfci .getWeights ().set (0.03f , 2 , 0 );
78+ cnnfci .getWeights ().set (0.001f , 3 , 0 );
79+ cnnfci .getWeights ().set (0.005f , 0 , 1 );
80+ cnnfci .getWeights ().set (0.04f , 1 , 1 );
81+ cnnfci .getWeights ().set (0.02f , 2 , 1 );
82+ cnnfci .getWeights ().set (0.009f , 3 , 1 );
83+
84+ FullyConnected cnnfco = (FullyConnected ) cnn .getOutputLayer ().getConnections ().get (0 );
85+ cnnfco .getWeights ().set (0.05f , 0 , 0 );
86+ cnnfco .getWeights ().set (0.08f , 0 , 1 );
87+
88+ // MLP
89+ NeuralNetworkImpl mlp = NNFactory .mlpSigmoid (new int [] { 2 , 4 , 1 }, false );
90+
91+ FullyConnected mlpfci = (FullyConnected ) mlp .getOutputLayer ().getConnections ().get (0 ).getInputLayer ().getConnections ().get (0 );
92+ mlpfci .getWeights ().set (0.02f , 0 , 0 );
93+ mlpfci .getWeights ().set (0.01f , 1 , 0 );
94+ mlpfci .getWeights ().set (0.03f , 2 , 0 );
95+ mlpfci .getWeights ().set (0.001f , 3 , 0 );
96+ mlpfci .getWeights ().set (0.005f , 0 , 1 );
97+ mlpfci .getWeights ().set (0.04f , 1 , 1 );
98+ mlpfci .getWeights ().set (0.02f , 2 , 1 );
99+ mlpfci .getWeights ().set (0.009f , 3 , 1 );
100+
101+ FullyConnected mlpfco = (FullyConnected ) mlp .getOutputLayer ().getConnections ().get (0 );
102+ mlpfco .getWeights ().set (0.05f , 0 , 0 );
103+ mlpfco .getWeights ().set (0.08f , 0 , 1 );
104+
105+ // compare bp
106+ SimpleInputProvider inputProvider = new SimpleInputProvider (new float [][] { {0 , 0 }, {0 , 1 }, {1 , 0 }, {1 , 1 } }, new float [][] { {0 }, {1 }, {1 }, {0 } });
107+
108+ BackPropagationTrainer <?> mlpbpt = TrainerFactory .backPropagation (mlp , inputProvider , inputProvider , new XorOutputError (), null , 1f , 0f , 0f , 0f , 0f , 1 , 1 , 10000 );
109+ mlpbpt .train ();
110+ mlpbpt .test ();
111+
112+ BackPropagationTrainer <?> cnnbpt = TrainerFactory .backPropagation (cnn , inputProvider , inputProvider , new XorOutputError (), null , 1f , 0f , 0f , 0f , 0f , 1 , 1 , 10000 );
113+ cnnbpt .train ();
114+ cnnbpt .test ();
115+
116+ assertEquals (mlpbpt .getOutputError ().getTotalNetworkError (), cnnbpt .getOutputError ().getTotalNetworkError (), 0 );
117+ assertTrue (Arrays .equals (cnnfco .getWeights ().getElements (), mlpfco .getWeights ().getElements ()));
118+ assertTrue (Arrays .equals (cnnfci .getWeights ().getElements (), mlpfci .getWeights ().getElements ()));
119+ }
53120}
0 commit comments