1
- /** *****************************************************************************
1
+ /* *****************************************************************************
2
2
* Copyright (c) 2015-2019 Skymind, Inc.
3
3
*
4
4
* This program and the accompanying materials are made available under the
34
34
import org .deeplearning4j .nn .conf .MultiLayerConfiguration ;
35
35
import org .deeplearning4j .nn .conf .NeuralNetConfiguration ;
36
36
import org .deeplearning4j .nn .conf .distribution .Distribution ;
37
- import org .deeplearning4j .nn .conf .distribution .GaussianDistribution ;
38
37
import org .deeplearning4j .nn .conf .distribution .NormalDistribution ;
39
38
import org .deeplearning4j .nn .conf .inputs .InputType ;
40
39
import org .deeplearning4j .nn .conf .inputs .InvalidInputTypeException ;
41
40
import org .deeplearning4j .nn .conf .layers .*;
42
41
import org .deeplearning4j .nn .multilayer .MultiLayerNetwork ;
43
42
import org .deeplearning4j .nn .weights .WeightInit ;
43
+ import org .deeplearning4j .nn .weights .WeightInitDistribution ;
44
44
import org .deeplearning4j .optimize .api .InvocationType ;
45
45
import org .deeplearning4j .optimize .listeners .EvaluativeListener ;
46
46
import org .deeplearning4j .optimize .listeners .ScoreIterationListener ;
61
61
import java .io .File ;
62
62
import java .util .Arrays ;
63
63
import java .util .List ;
64
+ import java .util .Objects ;
64
65
import java .util .Random ;
65
66
66
67
import static java .lang .Math .toIntExact ;
@@ -92,43 +93,42 @@ public class AnimalsClassification {
92
93
protected static long seed = 42 ;
93
94
protected static Random rng = new Random (seed );
94
95
protected static int epochs = 100 ;
95
- protected static double splitTrainTest = 0.8 ;
96
96
protected static boolean save = false ;
97
- protected static int maxPathsPerLabel =18 ;
98
97
99
- protected static String modelType = "LeNet" ; // LeNet, AlexNet or Custom but you need to fill it out
100
98
private int numLabels ;
101
99
102
100
public static String dataLocalPath ;
103
101
104
- public void run (String [] args ) throws Exception {
102
+ public void run () throws Exception {
105
103
106
104
dataLocalPath = DownloaderUtility .ANIMALS .Download ();
107
- /** cd
105
+ /* cd
108
106
* Data Setup -> organize and limit data file paths:
109
107
* - mainPath = path to image files
110
108
* - fileSplit = define basic dataset split with limits on format
111
109
* - pathFilter = define additional file load filter to limit size and balance batch content
112
- ** /
110
+ */
113
111
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator ();
114
112
File mainPath = new File (dataLocalPath );
115
113
FileSplit fileSplit = new FileSplit (mainPath , NativeImageLoader .ALLOWED_FORMATS , rng );
116
114
int numExamples = toIntExact (fileSplit .length ());
117
- numLabels = fileSplit .getRootDir ().listFiles (File ::isDirectory ).length ; //This only works if your root is clean: only label subdirs.
115
+ numLabels = Objects .requireNonNull (fileSplit .getRootDir ().listFiles (File ::isDirectory )).length ; //This only works if your root is clean: only label subdirs.
116
+ int maxPathsPerLabel = 18 ;
118
117
BalancedPathFilter pathFilter = new BalancedPathFilter (rng , labelMaker , numExamples , numLabels , maxPathsPerLabel );
119
118
120
- /**
119
+ /*
121
120
* Data Setup -> train test split
122
121
* - inputSplit = define train and test split
123
- **/
122
+ */
123
+ double splitTrainTest = 0.8 ;
124
124
InputSplit [] inputSplit = fileSplit .sample (pathFilter , splitTrainTest , 1 - splitTrainTest );
125
125
InputSplit trainData = inputSplit [0 ];
126
126
InputSplit testData = inputSplit [1 ];
127
127
128
- /**
128
+ /*
129
129
* Data Setup -> transformation
130
- * - Transform = how to tranform images and generate large dataset to train on
131
- ** /
130
+ * - Transform = how to transform images and generate large dataset to train on
131
+ */
132
132
ImageTransform flipTransform1 = new FlipImageTransform (rng );
133
133
ImageTransform flipTransform2 = new FlipImageTransform (new Random (123 ));
134
134
ImageTransform warpTransform = new WarpImageTransform (rng , 42 );
@@ -137,11 +137,12 @@ public void run(String[] args) throws Exception {
137
137
new Pair <>(flipTransform2 ,0.8 ),
138
138
new Pair <>(warpTransform ,0.5 ));
139
139
140
+ //noinspection ConstantConditions
140
141
ImageTransform transform = new PipelineImageTransform (pipeline ,shuffle );
141
- /**
142
+ /*
142
143
* Data Setup -> normalization
143
144
* - how to normalize images and generate large dataset to train on
144
- ** /
145
+ */
145
146
DataNormalization scaler = new ImagePreProcessingScaler (0 , 1 );
146
147
147
148
log .info ("Build model...." );
@@ -150,7 +151,10 @@ public void run(String[] args) throws Exception {
150
151
// MultiLayerNetwork network = new AlexNet(height, width, channels, numLabels, seed, iterations).init();
151
152
152
153
MultiLayerNetwork network ;
154
+ // LeNet, AlexNet or Custom but you need to fill it out
155
+ String modelType = "LeNet" ;
153
156
switch (modelType ) {
157
+ //noinspection ConstantConditions
154
158
case "LeNet" :
155
159
network = lenetModel ();
156
160
break ;
@@ -168,12 +172,12 @@ public void run(String[] args) throws Exception {
168
172
UIServer uiServer = UIServer .getInstance ();
169
173
StatsStorage statsStorage = new FileStatsStorage (new File (System .getProperty ("java.io.tmpdir" ), "ui-stats.dl4j" ));
170
174
uiServer .attach (statsStorage );
171
- /**
175
+ /*
172
176
* Data Setup -> define how to load data into net:
173
177
* - recordReader = the reader that loads and converts image data pass in inputSplit to initialize
174
178
* - dataIter = a generator that only loads one batch at a time into memory to save memory
175
179
* - trainIter = uses MultipleEpochsIterator to ensure model runs through the data for all epochs
176
- ** /
180
+ */
177
181
ImageRecordReader trainRR = new ImageRecordReader (height , width , channels , labelMaker );
178
182
DataSetIterator trainIter ;
179
183
@@ -221,6 +225,7 @@ public void run(String[] args) throws Exception {
221
225
log .info ("****************Example finished********************" );
222
226
}
223
227
228
+ @ SuppressWarnings ("SameParameterValue" )
224
229
private ConvolutionLayer convInit (String name , int in , int out , int [] kernel , int [] stride , int [] pad , double bias ) {
225
230
return new ConvolutionLayer .Builder (kernel , stride , pad ).name (name ).nIn (in ).nOut (out ).biasInit (bias ).build ();
226
231
}
@@ -229,6 +234,7 @@ private ConvolutionLayer conv3x3(String name, int out, double bias) {
229
234
return new ConvolutionLayer .Builder (new int []{3 ,3 }, new int [] {1 ,1 }, new int [] {1 ,1 }).name (name ).nOut (out ).biasInit (bias ).build ();
230
235
}
231
236
237
+ @ SuppressWarnings ("SameParameterValue" )
232
238
private ConvolutionLayer conv5x5 (String name , int out , int [] stride , int [] pad , double bias ) {
233
239
return new ConvolutionLayer .Builder (new int []{5 ,5 }, stride , pad ).name (name ).nOut (out ).biasInit (bias ).build ();
234
240
}
@@ -237,21 +243,21 @@ private SubsamplingLayer maxPool(String name, int[] kernel) {
237
243
return new SubsamplingLayer .Builder (kernel , new int []{2 ,2 }).name (name ).build ();
238
244
}
239
245
246
+ @ SuppressWarnings ("SameParameterValue" )
240
247
private DenseLayer fullyConnected (String name , int out , double bias , double dropOut , Distribution dist ) {
241
- return new DenseLayer .Builder ().name (name ).nOut (out ).biasInit (bias ).dropOut (dropOut ).dist ( dist ).build ();
248
+ return new DenseLayer .Builder ().name (name ).nOut (out ).biasInit (bias ).dropOut (dropOut ).weightInit ( new WeightInitDistribution ( dist ) ).build ();
242
249
}
243
250
244
- public MultiLayerNetwork lenetModel () {
245
- /**
251
+ private MultiLayerNetwork lenetModel () {
252
+ /*
246
253
* Revisde Lenet Model approach developed by ramgo2 achieves slightly above random
247
254
* Reference: https://gist.github.com/ramgo2/833f12e92359a2da9e5c2fb6333351c5
248
- ** /
255
+ */
249
256
MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
250
257
.seed (seed )
251
258
.l2 (0.005 )
252
259
.activation (Activation .RELU )
253
260
.weightInit (WeightInit .XAVIER )
254
- // .updater(new Nadam(1e-4))
255
261
.updater (new AdaDelta ())
256
262
.list ()
257
263
.layer (0 , convInit ("cnn1" , channels , 50 , new int []{5 , 5 }, new int []{1 , 1 }, new int []{0 , 0 }, 0 ))
@@ -270,12 +276,12 @@ public MultiLayerNetwork lenetModel() {
270
276
271
277
}
272
278
273
- public MultiLayerNetwork alexnetModel () {
274
- /**
279
+ private MultiLayerNetwork alexnetModel () {
280
+ /*
275
281
* AlexNet model interpretation based on the original paper ImageNet Classification with Deep Convolutional Neural Networks
276
282
* and the imagenetExample code referenced.
277
283
* http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
278
- ** /
284
+ */
279
285
280
286
double nonZeroBias = 1 ;
281
287
double dropOut = 0.5 ;
@@ -298,8 +304,8 @@ public MultiLayerNetwork alexnetModel() {
298
304
.layer (conv3x3 ("cnn4" , 384 , nonZeroBias ))
299
305
.layer (conv3x3 ("cnn5" , 256 , nonZeroBias ))
300
306
.layer (maxPool ("maxpool3" , new int []{3 ,3 }))
301
- .layer (fullyConnected ("ffn1" , 4096 , nonZeroBias , dropOut , new GaussianDistribution (0 , 0.005 )))
302
- .layer (fullyConnected ("ffn2" , 4096 , nonZeroBias , dropOut , new GaussianDistribution (0 , 0.005 )))
307
+ .layer (fullyConnected ("ffn1" , 4096 , nonZeroBias , dropOut , new NormalDistribution (0 , 0.005 )))
308
+ .layer (fullyConnected ("ffn2" , 4096 , nonZeroBias , dropOut , new NormalDistribution (0 , 0.005 )))
303
309
.layer (new OutputLayer .Builder (LossFunctions .LossFunction .NEGATIVELOGLIKELIHOOD )
304
310
.name ("output" )
305
311
.nOut (numLabels )
@@ -312,15 +318,15 @@ public MultiLayerNetwork alexnetModel() {
312
318
313
319
}
314
320
315
- public static MultiLayerNetwork customModel () {
316
- /**
321
+ private static MultiLayerNetwork customModel () {
322
+ /*
317
323
* Use this method to build your own custom model.
318
- ** /
324
+ */
319
325
return null ;
320
326
}
321
327
322
328
public static void main (String [] args ) throws Exception {
323
- new AnimalsClassification ().run (args );
329
+ new AnimalsClassification ().run ();
324
330
}
325
331
326
332
}
0 commit comments