|
7 | 7 |
|
8 | 8 | import com.amd.aparapi.Kernel.EXECUTION_MODE; |
9 | 9 | import com.github.neuralnetworks.architecture.NeuralNetworkImpl; |
10 | | -import com.github.neuralnetworks.architecture.types.Autoencoder; |
11 | 10 | import com.github.neuralnetworks.architecture.types.NNFactory; |
12 | | -import com.github.neuralnetworks.architecture.types.RBM; |
13 | 11 | import com.github.neuralnetworks.input.MultipleNeuronsOutputError; |
14 | 12 | import com.github.neuralnetworks.input.ScalingInputFunction; |
15 | 13 | import com.github.neuralnetworks.samples.mnist.MnistInputProvider; |
16 | | -import com.github.neuralnetworks.training.Trainer; |
17 | 14 | import com.github.neuralnetworks.training.TrainerFactory; |
18 | 15 | import com.github.neuralnetworks.training.backpropagation.BackPropagationTrainer; |
19 | 16 | import com.github.neuralnetworks.training.events.LogTrainingListener; |
20 | 17 | import com.github.neuralnetworks.training.random.MersenneTwisterRandomInitializer; |
21 | 18 | import com.github.neuralnetworks.training.random.NNRandomInitializer; |
22 | | -import com.github.neuralnetworks.training.rbm.AparapiCDTrainer; |
23 | 19 | import com.github.neuralnetworks.util.Environment; |
24 | 20 |
|
25 | 21 | /** |
@@ -72,44 +68,6 @@ public void testSigmoidHiddenBP() { |
72 | 68 | assertEquals(0, bpt.getOutputError().getTotalNetworkError(), 0.1); |
73 | 69 | } |
74 | 70 |
|
75 | | - @Test |
76 | | - public void testRBM() { |
77 | | - RBM rbm = NNFactory.rbm(784, 10, false); |
78 | | - MnistInputProvider trainInputProvider = new MnistInputProvider("train-images.idx3-ubyte", "train-labels.idx1-ubyte"); |
79 | | - trainInputProvider.addInputModifier(new ScalingInputFunction(255)); |
80 | | - MnistInputProvider testInputProvider = new MnistInputProvider("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte"); |
81 | | - testInputProvider.addInputModifier(new ScalingInputFunction(255)); |
82 | | - |
83 | | - AparapiCDTrainer t = TrainerFactory.cdSigmoidBinaryTrainer(rbm, trainInputProvider, testInputProvider, new MultipleNeuronsOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f, 1, 1, 1, false); |
84 | | - |
85 | | - t.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName(), false, true)); |
86 | | - Environment.getInstance().setExecutionMode(EXECUTION_MODE.CPU); |
87 | | - t.train(); |
88 | | - t.test(); |
89 | | - |
90 | | - assertEquals(0, t.getOutputError().getTotalNetworkError(), 0.8); |
91 | | - } |
92 | | - |
93 | | - @Test |
94 | | - public void testAE() { |
95 | | - Autoencoder nn = NNFactory.autoencoderSigmoid(784, 10, true); |
96 | | - |
97 | | - MnistInputProvider trainInputProvider = new MnistInputProvider("train-images.idx3-ubyte", "train-labels.idx1-ubyte"); |
98 | | - trainInputProvider.addInputModifier(new ScalingInputFunction(255)); |
99 | | - MnistInputProvider testInputProvider = new MnistInputProvider("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte"); |
100 | | - testInputProvider.addInputModifier(new ScalingInputFunction(255)); |
101 | | - |
102 | | - Trainer<?> t = TrainerFactory.backPropagationAutoencoder(nn, trainInputProvider, testInputProvider, new MultipleNeuronsOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f, 0f, 1, 1000, 1); |
103 | | - |
104 | | - t.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName(), false, true)); |
105 | | - Environment.getInstance().setExecutionMode(EXECUTION_MODE.CPU); |
106 | | - t.train(); |
107 | | - nn.removeLayer(nn.getOutputLayer()); |
108 | | - t.test(); |
109 | | - |
110 | | - assertEquals(0, t.getOutputError().getTotalNetworkError(), 0.1); |
111 | | - } |
112 | | - |
113 | 71 | @Test |
114 | 72 | public void testLeNetSmall() { |
115 | 73 | // cpu execution mode |
@@ -180,6 +138,9 @@ public void testLeNetTiny() { |
180 | 138 | */ |
181 | 139 | @Test |
182 | 140 | public void testLeNetTiny2() { |
| 141 | + Environment.getInstance().setUseDataSharedMemory(false); |
| 142 | + Environment.getInstance().setUseWeightsSharedMemory(false); |
| 143 | + |
183 | 144 | // very simple convolutional network with a single convolutional layer with 6 5x5 filters and a single 2x2 max pooling layer |
184 | 145 | NeuralNetworkImpl nn = NNFactory.convNN(new int[][] { { 28, 28, 1 }, { 5, 5, 6, 1 }, {2, 2}, {10} }, true); |
185 | 146 | nn.setLayerCalculator(NNFactory.lcSigmoid(nn, null)); |
|
0 commit comments