Skip to content

Commit 8e1b23c

Browse files
committed
AnimalClassification.
Signed-off-by: Robert Altena <[email protected]>
1 parent 7e6ae4b commit 8e1b23c

File tree

1 file changed

+38
-32
lines changed

1 file changed

+38
-32
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/AnimalsClassification.java

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -34,13 +34,13 @@
3434
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
3535
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
3636
import org.deeplearning4j.nn.conf.distribution.Distribution;
37-
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
3837
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
3938
import org.deeplearning4j.nn.conf.inputs.InputType;
4039
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
4140
import org.deeplearning4j.nn.conf.layers.*;
4241
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
4342
import org.deeplearning4j.nn.weights.WeightInit;
43+
import org.deeplearning4j.nn.weights.WeightInitDistribution;
4444
import org.deeplearning4j.optimize.api.InvocationType;
4545
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
4646
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
@@ -61,6 +61,7 @@
6161
import java.io.File;
6262
import java.util.Arrays;
6363
import java.util.List;
64+
import java.util.Objects;
6465
import java.util.Random;
6566

6667
import static java.lang.Math.toIntExact;
@@ -92,43 +93,42 @@ public class AnimalsClassification {
9293
protected static long seed = 42;
9394
protected static Random rng = new Random(seed);
9495
protected static int epochs = 100;
95-
protected static double splitTrainTest = 0.8;
9696
protected static boolean save = false;
97-
protected static int maxPathsPerLabel=18;
9897

99-
protected static String modelType = "LeNet"; // LeNet, AlexNet or Custom but you need to fill it out
10098
private int numLabels;
10199

102100
public static String dataLocalPath;
103101

104-
public void run(String[] args) throws Exception {
102+
public void run() throws Exception {
105103

106104
dataLocalPath = DownloaderUtility.ANIMALS.Download();
107-
/**cd
105+
/* cd
108106
* Data Setup -> organize and limit data file paths:
109107
* - mainPath = path to image files
110108
* - fileSplit = define basic dataset split with limits on format
111109
* - pathFilter = define additional file load filter to limit size and balance batch content
112-
**/
110+
*/
113111
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
114112
File mainPath = new File(dataLocalPath);
115113
FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
116114
int numExamples = toIntExact(fileSplit.length());
117-
numLabels = fileSplit.getRootDir().listFiles(File::isDirectory).length; //This only works if your root is clean: only label subdirs.
115+
numLabels = Objects.requireNonNull(fileSplit.getRootDir().listFiles(File::isDirectory)).length; //This only works if your root is clean: only label subdirs.
116+
int maxPathsPerLabel = 18;
118117
BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, maxPathsPerLabel);
119118

120-
/**
119+
/*
121120
* Data Setup -> train test split
122121
* - inputSplit = define train and test split
123-
**/
122+
*/
123+
double splitTrainTest = 0.8;
124124
InputSplit[] inputSplit = fileSplit.sample(pathFilter, splitTrainTest, 1 - splitTrainTest);
125125
InputSplit trainData = inputSplit[0];
126126
InputSplit testData = inputSplit[1];
127127

128-
/**
128+
/*
129129
* Data Setup -> transformation
130-
* - Transform = how to tranform images and generate large dataset to train on
131-
**/
130+
* - Transform = how to transform images and generate large dataset to train on
131+
*/
132132
ImageTransform flipTransform1 = new FlipImageTransform(rng);
133133
ImageTransform flipTransform2 = new FlipImageTransform(new Random(123));
134134
ImageTransform warpTransform = new WarpImageTransform(rng, 42);
@@ -137,11 +137,12 @@ public void run(String[] args) throws Exception {
137137
new Pair<>(flipTransform2,0.8),
138138
new Pair<>(warpTransform,0.5));
139139

140+
//noinspection ConstantConditions
140141
ImageTransform transform = new PipelineImageTransform(pipeline,shuffle);
141-
/**
142+
/*
142143
* Data Setup -> normalization
143144
* - how to normalize images and generate large dataset to train on
144-
**/
145+
*/
145146
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
146147

147148
log.info("Build model....");
@@ -150,7 +151,10 @@ public void run(String[] args) throws Exception {
150151
// MultiLayerNetwork network = new AlexNet(height, width, channels, numLabels, seed, iterations).init();
151152

152153
MultiLayerNetwork network;
154+
// LeNet, AlexNet or Custom but you need to fill it out
155+
String modelType = "LeNet";
153156
switch (modelType) {
157+
//noinspection ConstantConditions
154158
case "LeNet":
155159
network = lenetModel();
156160
break;
@@ -168,12 +172,12 @@ public void run(String[] args) throws Exception {
168172
UIServer uiServer = UIServer.getInstance();
169173
StatsStorage statsStorage = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "ui-stats.dl4j"));
170174
uiServer.attach(statsStorage);
171-
/**
175+
/*
172176
* Data Setup -> define how to load data into net:
173177
* - recordReader = the reader that loads and converts image data pass in inputSplit to initialize
174178
* - dataIter = a generator that only loads one batch at a time into memory to save memory
175179
* - trainIter = uses MultipleEpochsIterator to ensure model runs through the data for all epochs
176-
**/
180+
*/
177181
ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
178182
DataSetIterator trainIter;
179183

@@ -221,6 +225,7 @@ public void run(String[] args) throws Exception {
221225
log.info("****************Example finished********************");
222226
}
223227

228+
@SuppressWarnings("SameParameterValue")
224229
private ConvolutionLayer convInit(String name, int in, int out, int[] kernel, int[] stride, int[] pad, double bias) {
225230
return new ConvolutionLayer.Builder(kernel, stride, pad).name(name).nIn(in).nOut(out).biasInit(bias).build();
226231
}
@@ -229,6 +234,7 @@ private ConvolutionLayer conv3x3(String name, int out, double bias) {
229234
return new ConvolutionLayer.Builder(new int[]{3,3}, new int[] {1,1}, new int[] {1,1}).name(name).nOut(out).biasInit(bias).build();
230235
}
231236

237+
@SuppressWarnings("SameParameterValue")
232238
private ConvolutionLayer conv5x5(String name, int out, int[] stride, int[] pad, double bias) {
233239
return new ConvolutionLayer.Builder(new int[]{5,5}, stride, pad).name(name).nOut(out).biasInit(bias).build();
234240
}
@@ -237,21 +243,21 @@ private SubsamplingLayer maxPool(String name, int[] kernel) {
237243
return new SubsamplingLayer.Builder(kernel, new int[]{2,2}).name(name).build();
238244
}
239245

246+
@SuppressWarnings("SameParameterValue")
240247
private DenseLayer fullyConnected(String name, int out, double bias, double dropOut, Distribution dist) {
241-
return new DenseLayer.Builder().name(name).nOut(out).biasInit(bias).dropOut(dropOut).dist(dist).build();
248+
return new DenseLayer.Builder().name(name).nOut(out).biasInit(bias).dropOut(dropOut).weightInit(new WeightInitDistribution(dist)).build();
242249
}
243250

244-
public MultiLayerNetwork lenetModel() {
245-
/**
251+
private MultiLayerNetwork lenetModel() {
252+
/*
246253
* Revised LeNet Model approach developed by ramgo2 achieves slightly above random
247254
* Reference: https://gist.github.com/ramgo2/833f12e92359a2da9e5c2fb6333351c5
248-
**/
255+
*/
249256
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
250257
.seed(seed)
251258
.l2(0.005)
252259
.activation(Activation.RELU)
253260
.weightInit(WeightInit.XAVIER)
254-
// .updater(new Nadam(1e-4))
255261
.updater(new AdaDelta())
256262
.list()
257263
.layer(0, convInit("cnn1", channels, 50 , new int[]{5, 5}, new int[]{1, 1}, new int[]{0, 0}, 0))
@@ -270,12 +276,12 @@ public MultiLayerNetwork lenetModel() {
270276

271277
}
272278

273-
public MultiLayerNetwork alexnetModel() {
274-
/**
279+
private MultiLayerNetwork alexnetModel() {
280+
/*
275281
* AlexNet model interpretation based on the original paper ImageNet Classification with Deep Convolutional Neural Networks
276282
* and the imagenetExample code referenced.
277283
* http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
278-
**/
284+
*/
279285

280286
double nonZeroBias = 1;
281287
double dropOut = 0.5;
@@ -298,8 +304,8 @@ public MultiLayerNetwork alexnetModel() {
298304
.layer(conv3x3("cnn4", 384, nonZeroBias))
299305
.layer(conv3x3("cnn5", 256, nonZeroBias))
300306
.layer(maxPool("maxpool3", new int[]{3,3}))
301-
.layer(fullyConnected("ffn1", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
302-
.layer(fullyConnected("ffn2", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
307+
.layer(fullyConnected("ffn1", 4096, nonZeroBias, dropOut, new NormalDistribution(0, 0.005)))
308+
.layer(fullyConnected("ffn2", 4096, nonZeroBias, dropOut, new NormalDistribution(0, 0.005)))
303309
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
304310
.name("output")
305311
.nOut(numLabels)
@@ -312,15 +318,15 @@ public MultiLayerNetwork alexnetModel() {
312318

313319
}
314320

315-
public static MultiLayerNetwork customModel() {
316-
/**
321+
private static MultiLayerNetwork customModel() {
322+
/*
317323
* Use this method to build your own custom model.
318-
**/
324+
*/
319325
return null;
320326
}
321327

322328
public static void main(String[] args) throws Exception {
323-
new AnimalsClassification().run(args);
329+
new AnimalsClassification().run();
324330
}
325331

326332
}

0 commit comments

Comments
 (0)