
Commit 331531f

Fix
Signed-off-by: Alex Black <[email protected]>
1 parent a0f6e8e commit 331531f

File tree

1 file changed: +97 −0 lines
  • dl4j-examples_javafx/src/main/java/org/deeplearning4j/examples/userInterface/util

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.examples.userInterface.util;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.IOException;

/**
 * Utility methods for the user interface examples: builds a small convolutional
 * network for MNIST and provides an MNIST training data iterator.
 *
 * Created by Alex on 11/11/2016.
 */
public class UIExampleUtils {

    public static MultiLayerNetwork getMnistNetwork(){

        int nChannels = 1;   // Number of input channels
        int outputNum = 10;  // The number of possible outcomes
        int seed = 123;      // Random number seed, for reproducibility

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .l2(0.0005)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam(0.001))
            .list()
            .layer(new ConvolutionLayer.Builder(5, 5)
                //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
                .nIn(nChannels)
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.LEAKYRELU)
                .build())
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2,2)
                .stride(2,2)
                .build())
            .layer(new ConvolutionLayer.Builder(5, 5)
                //Note that nIn need not be specified in later layers
                .stride(1, 1)
                .nOut(50)
                .activation(Activation.LEAKYRELU)
                .build())
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2,2)
                .stride(2,2)
                .build())
            .layer(new DenseLayer.Builder().activation(Activation.LEAKYRELU).nOut(500).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
            .setInputType(InputType.convolutionalFlat(28,28,1))
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        return net;
    }

    public static DataSetIterator getMnistData(){
        try{
            // Batch size 64, training set, RNG seed 12345
            return new MnistDataSetIterator(64,true,12345);
        }catch (IOException e){
            throw new RuntimeException(e);
        }
    }

}
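
For context, a minimal usage sketch (not part of this commit) of how these helpers might be driven from one of the userInterface examples, assuming the DL4J training-UI module is on the classpath. The UIExampleSketch class is hypothetical, and the exact packages of UIServer, StatsListener and InMemoryStatsStorage vary between DL4J releases; everything else uses only the two methods added here.

// Hypothetical sketch: train the MNIST network while streaming stats to the
// browser-based training UI. UI-related import paths are assumptions and may
// differ depending on the DL4J version in use.
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class UIExampleSketch {
    public static void main(String[] args) {
        MultiLayerNetwork net = UIExampleUtils.getMnistNetwork();
        DataSetIterator trainData = UIExampleUtils.getMnistData();

        // Route training statistics to the training UI
        StatsStorage statsStorage = new InMemoryStatsStorage();
        UIServer.getInstance().attach(statsStorage);
        net.setListeners(new StatsListener(statsStorage));

        // One pass over the 64-example MNIST batches returned by getMnistData()
        net.fit(trainData);
    }
}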

0 commit comments
