Skip to content

Commit 306dbe9

Browse files
committed
complete dataexamples refactoring.
Signed-off-by: Robert Altena <[email protected]>
1 parent 861316d commit 306dbe9

File tree

7 files changed

+36
-44
lines changed

7 files changed

+36
-44
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples/MnistImagePipelineExampleAddNeuralNet.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -22,7 +22,6 @@
2222
import org.datavec.image.loader.NativeImageLoader;
2323
import org.datavec.image.recordreader.ImageRecordReader;
2424
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
25-
import org.deeplearning4j.eval.Evaluation;
2625
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
2726
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
2827
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -32,6 +31,7 @@
3231
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
3332
import org.deeplearning4j.nn.weights.WeightInit;
3433
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
34+
import org.nd4j.evaluation.classification.Evaluation;
3535
import org.nd4j.linalg.activations.Activation;
3636
import org.nd4j.linalg.api.ndarray.INDArray;
3737
import org.nd4j.linalg.dataset.DataSet;
@@ -64,12 +64,10 @@
6464
public class MnistImagePipelineExampleAddNeuralNet {
6565
private static Logger log = LoggerFactory.getLogger(MnistImagePipelineExampleAddNeuralNet.class);
6666

67-
/** Data URL for downloading */
68-
public static final String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
69-
7067
/** Location to save and extract the training/testing data */
71-
public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/");
68+
private static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/");
7269

70+
@SuppressWarnings("DuplicatedCode")
7371
public static void main(String[] args) throws Exception {
7472
// image information
7573
// 28 * 28 grayscale

dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples/MnistImagePipelineExampleLoad.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -22,8 +22,8 @@
2222
import org.datavec.image.loader.NativeImageLoader;
2323
import org.datavec.image.recordreader.ImageRecordReader;
2424
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
25-
import org.deeplearning4j.eval.Evaluation;
2625
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
26+
import org.nd4j.evaluation.classification.Evaluation;
2727
import org.nd4j.linalg.api.ndarray.INDArray;
2828
import org.nd4j.linalg.dataset.api.DataSet;
2929
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@@ -53,12 +53,10 @@
5353
public class MnistImagePipelineExampleLoad {
5454
private static Logger log = LoggerFactory.getLogger(MnistImagePipelineExampleLoad.class);
5555

56-
/** Data URL for downloading */
57-
public static final String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
58-
5956
/** Location to save and extract the training/testing data */
60-
public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/");
57+
private static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/");
6158

59+
@SuppressWarnings("DuplicatedCode")
6260
public static void main(String[] args) throws Exception {
6361
// image information
6462
// 28 * 28 grayscale

dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples/MnistImagePipelineExampleSave.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -61,12 +61,10 @@
6161
public class MnistImagePipelineExampleSave {
6262
private static Logger log = LoggerFactory.getLogger(MnistImagePipelineExampleSave.class);
6363

64-
/** Data URL for downloading */
65-
public static final String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
66-
6764
/** Location to save and extract the training/testing data */
68-
public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/");
65+
private static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/");
6966

67+
@SuppressWarnings("DuplicatedCode")
7068
public static void main(String[] args) throws Exception {
7169
// image information
7270
// 28 * 28 grayscale
@@ -154,7 +152,8 @@ This class downloadData() downloads the data
154152
boolean saveUpdater = false;
155153

156154
// ModelSerializer needs modelname, saveUpdater, Location
157-
model.save(locationToSave, saveUpdater);
155+
//noinspection ConstantConditions
156+
model.save(locationToSave, saveUpdater);
158157
}
159158

160159
}

dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples/MnistImagePipelineLoadChooser.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -29,6 +29,7 @@
2929
import java.io.File;
3030
import java.util.Arrays;
3131
import java.util.List;
32+
import java.util.Objects;
3233

3334
/**
3435
* This code example is featured in this youtube video
@@ -54,26 +55,26 @@ public class MnistImagePipelineLoadChooser {
5455
private static Logger log = LoggerFactory.getLogger(MnistImagePipelineLoadChooser.class);
5556

5657
/** Location to save and extract the training/testing data */
57-
public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/");
58+
private static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/");
5859

5960
/*
6061
Create a popup window to allow you to chose an image file to test against the
6162
trained Neural Network
6263
Chosen images will be automatically
6364
scaled to 28*28 grayscale
6465
*/
65-
public static String fileChose() {
66+
private static String fileChose() {
6667
JFileChooser fc = new JFileChooser();
6768
int ret = fc.showOpenDialog(null);
6869
if (ret == JFileChooser.APPROVE_OPTION) {
6970
File file = fc.getSelectedFile();
70-
String filename = file.getAbsolutePath();
71-
return filename;
71+
return file.getAbsolutePath();
7272
} else {
7373
return null;
7474
}
7575
}
7676

77+
@SuppressWarnings("DuplicatedCode")
7778
public static void main(String[] args) throws Exception {
7879
int height = 28;
7980
int width = 28;
@@ -86,7 +87,7 @@ public static void main(String[] args) throws Exception {
8687
List<Integer> labelList = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
8788

8889
// pop up file chooser
89-
String filechose = fileChose().toString();
90+
String filechose = fileChose();
9091

9192
//LOAD NEURAL NETWORK
9293

@@ -105,7 +106,7 @@ public static void main(String[] args) throws Exception {
105106

106107
log.info("TEST YOUR IMAGE AGAINST SAVED NETWORK");
107108
// FileChose is a string we will need a file
108-
File file = new File(filechose);
109+
File file = new File(Objects.requireNonNull(filechose));
109110

110111
// Use NativeImageLoader to convert to numerical matrix
111112
NativeImageLoader loader = new NativeImageLoader(height, width, channels);

dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples/MultiClassLogit.java

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -82,8 +82,7 @@ private static DataSet getIrisDataSet() throws Exception {
8282
.map(mapRowToDataSet)
8383
.collect(Collectors.toList());
8484

85-
if (reader != null)
86-
reader.close();
85+
reader.close();
8786

8887
DataSetIterator iter = new IteratorDataSetIterator(data.iterator(), 150);
8988
irisDataSet = iter.next();
@@ -116,8 +115,8 @@ private static DataSet getIrisDataSet() throws Exception {
116115
Nd4j.create(Arrays.copyOfRange(parsedRows, columns - 1, columns)));
117116
};
118117

119-
public static INDArray trainModel(DataSet trainDataSet, long maxIterations, double learningRate,
120-
double minLearningRate) {
118+
private static INDArray trainModel(DataSet trainDataSet, long maxIterations, double learningRate,
119+
double minLearningRate) {
121120
log.info("Training the model...");
122121
long start = System.currentTimeMillis();
123122
INDArray trainFeatures = prependConstant(trainDataSet);
@@ -139,7 +138,7 @@ public static INDArray trainModel(DataSet trainDataSet, long maxIterations, doub
139138
return finalModel;
140139
}
141140

142-
public static void testModel(DataSet testDataSet, INDArray params) {
141+
private static void testModel(DataSet testDataSet, INDArray params) {
143142
log.info("Testing the model...");
144143
INDArray testFeatures = prependConstant(testDataSet);
145144
INDArray testLabels = testDataSet.getLabels();
@@ -160,11 +159,10 @@ public static void testModel(DataSet testDataSet, INDArray params) {
160159
* @param dataset dataset
161160
* @return features
162161
*/
163-
public static INDArray prependConstant(DataSet dataset) {
164-
INDArray features = Nd4j.hstack(
165-
Nd4j.ones(dataset.getFeatures().size(0), 1),
166-
dataset.getFeatures());
167-
return features;
162+
private static INDArray prependConstant(DataSet dataset) {
163+
return Nd4j.hstack(
164+
Nd4j.ones(dataset.getFeatures().size(0), 1),
165+
dataset.getFeatures());
168166
}
169167

170168
/**
@@ -241,7 +239,7 @@ private static INDArray training(INDArray x, INDArray y, long maxIterations, dou
241239
INDArray params = Nd4j.rand((int)x.size(1), 1); //random guess
242240

243241
INDArray newParams = params.dup();
244-
INDArray optimalParams = params.dup();
242+
INDArray optimalParams;
245243

246244
for (int i = 0; i < maxIterations; i++) {
247245
INDArray gradients = gradient(x, y, params);
@@ -288,8 +286,7 @@ private static INDArray getClassLabels(INDArray labels, double label) {
288286
* @return predicted labels
289287
*/
290288
private static INDArray predictLabels(INDArray features, INDArray params) {
291-
INDArray predictions = features.mmul(params).argMax(1);
292-
return predictions;
289+
return features.mmul(params).argMax(1);
293290
}
294291

295292
private static double countCorrectPred(INDArray labels, INDArray predictions) {

dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples/PreprocessNormalizerExample.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -116,7 +116,6 @@ public static void main(String[] args) throws Exception {
116116
preProcessorIter.transform(firstBatch);
117117
log.info("\n{}",firstBatch);
118118
log.info("Note that this now gives the same results");
119-
break;
120119
}
121120

122121
log.info("If you are using batches and an iterator, set the preprocessor on your iterator to transform data automatically when next is called");

dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples/SVMLightExample.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -20,7 +20,6 @@
2020
import org.datavec.api.records.reader.impl.misc.SVMLightRecordReader;
2121
import org.datavec.api.split.FileSplit;
2222
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
23-
import org.deeplearning4j.eval.Evaluation;
2423
import org.deeplearning4j.examples.download.DownloaderUtility;
2524
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
2625
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -30,6 +29,7 @@
3029
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
3130
import org.deeplearning4j.nn.weights.WeightInit;
3231
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
32+
import org.nd4j.evaluation.classification.Evaluation;
3333
import org.nd4j.linalg.activations.Activation;
3434
import org.nd4j.linalg.api.ndarray.INDArray;
3535
import org.nd4j.linalg.dataset.DataSet;
@@ -76,7 +76,7 @@ public static void main(String[] args) throws Exception {
7676
log.info("Build model....");
7777
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
7878
.seed(seed)
79-
.trainingWorkspaceMode(WorkspaceMode.SEPARATE)
79+
.trainingWorkspaceMode(WorkspaceMode.ENABLED)
8080
.activation(Activation.RELU)
8181
.weightInit(WeightInit.XAVIER)
8282
.updater(Adam.builder().learningRate(0.02).beta1(0.9).beta2(0.999).build())

0 commit comments

Comments
 (0)