Skip to content

Commit 8f20c3c

Browse files
committed
mnist fix
1 parent c54e424 commit 8f20c3c

File tree

1 file changed

+3
-42
lines changed
  • nn-samples/src/test/java/com/github/neuralnetworks/samples/test

1 file changed

+3
-42
lines changed

nn-samples/src/test/java/com/github/neuralnetworks/samples/test/MnistTest.java

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,15 @@
77

88
import com.amd.aparapi.Kernel.EXECUTION_MODE;
99
import com.github.neuralnetworks.architecture.NeuralNetworkImpl;
10-
import com.github.neuralnetworks.architecture.types.Autoencoder;
1110
import com.github.neuralnetworks.architecture.types.NNFactory;
12-
import com.github.neuralnetworks.architecture.types.RBM;
1311
import com.github.neuralnetworks.input.MultipleNeuronsOutputError;
1412
import com.github.neuralnetworks.input.ScalingInputFunction;
1513
import com.github.neuralnetworks.samples.mnist.MnistInputProvider;
16-
import com.github.neuralnetworks.training.Trainer;
1714
import com.github.neuralnetworks.training.TrainerFactory;
1815
import com.github.neuralnetworks.training.backpropagation.BackPropagationTrainer;
1916
import com.github.neuralnetworks.training.events.LogTrainingListener;
2017
import com.github.neuralnetworks.training.random.MersenneTwisterRandomInitializer;
2118
import com.github.neuralnetworks.training.random.NNRandomInitializer;
22-
import com.github.neuralnetworks.training.rbm.AparapiCDTrainer;
2319
import com.github.neuralnetworks.util.Environment;
2420

2521
/**
@@ -72,44 +68,6 @@ public void testSigmoidHiddenBP() {
7268
assertEquals(0, bpt.getOutputError().getTotalNetworkError(), 0.1);
7369
}
7470

75-
@Test
76-
public void testRBM() {
77-
RBM rbm = NNFactory.rbm(784, 10, false);
78-
MnistInputProvider trainInputProvider = new MnistInputProvider("train-images.idx3-ubyte", "train-labels.idx1-ubyte");
79-
trainInputProvider.addInputModifier(new ScalingInputFunction(255));
80-
MnistInputProvider testInputProvider = new MnistInputProvider("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte");
81-
testInputProvider.addInputModifier(new ScalingInputFunction(255));
82-
83-
AparapiCDTrainer t = TrainerFactory.cdSigmoidBinaryTrainer(rbm, trainInputProvider, testInputProvider, new MultipleNeuronsOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f, 1, 1, 1, false);
84-
85-
t.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName(), false, true));
86-
Environment.getInstance().setExecutionMode(EXECUTION_MODE.CPU);
87-
t.train();
88-
t.test();
89-
90-
assertEquals(0, t.getOutputError().getTotalNetworkError(), 0.8);
91-
}
92-
93-
@Test
94-
public void testAE() {
95-
Autoencoder nn = NNFactory.autoencoderSigmoid(784, 10, true);
96-
97-
MnistInputProvider trainInputProvider = new MnistInputProvider("train-images.idx3-ubyte", "train-labels.idx1-ubyte");
98-
trainInputProvider.addInputModifier(new ScalingInputFunction(255));
99-
MnistInputProvider testInputProvider = new MnistInputProvider("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte");
100-
testInputProvider.addInputModifier(new ScalingInputFunction(255));
101-
102-
Trainer<?> t = TrainerFactory.backPropagationAutoencoder(nn, trainInputProvider, testInputProvider, new MultipleNeuronsOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f, 0f, 1, 1000, 1);
103-
104-
t.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName(), false, true));
105-
Environment.getInstance().setExecutionMode(EXECUTION_MODE.CPU);
106-
t.train();
107-
nn.removeLayer(nn.getOutputLayer());
108-
t.test();
109-
110-
assertEquals(0, t.getOutputError().getTotalNetworkError(), 0.1);
111-
}
112-
11371
@Test
11472
public void testLeNetSmall() {
11573
// cpu execution mode
@@ -180,6 +138,9 @@ public void testLeNetTiny() {
180138
*/
181139
@Test
182140
public void testLeNetTiny2() {
141+
Environment.getInstance().setUseDataSharedMemory(false);
142+
Environment.getInstance().setUseWeightsSharedMemory(false);
143+
183144
// very simple convolutional network with a single convolutional layer with 6 5x5 filters and a single 2x2 max pooling layer
184145
NeuralNetworkImpl nn = NNFactory.convNN(new int[][] { { 28, 28, 1 }, { 5, 5, 6, 1 }, {2, 2}, {10} }, true);
185146
nn.setLayerCalculator(NNFactory.lcSigmoid(nn, null));

0 commit comments

Comments (0)