Skip to content

Commit aa55533

Browse files
committed
transferlearning
Signed-off-by: Robert Altena <[email protected]>
1 parent 663f6e2 commit aa55533

File tree

7 files changed

+22
-32
lines changed

7 files changed

+22
-32
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/transferlearning/vgg16/EditAtBottleneckOthersFrozen.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -16,7 +16,6 @@
1616

1717
package org.deeplearning4j.examples.transferlearning.vgg16;
1818

19-
import org.deeplearning4j.eval.Evaluation;
2019
import org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers.FlowerDataSetIterator;
2120
import org.deeplearning4j.nn.conf.layers.DenseLayer;
2221
import org.deeplearning4j.nn.conf.layers.OutputLayer;
@@ -26,6 +25,7 @@
2625
import org.deeplearning4j.nn.weights.WeightInit;
2726
import org.deeplearning4j.zoo.ZooModel;
2827
import org.deeplearning4j.zoo.model.VGG16;
28+
import org.nd4j.evaluation.classification.Evaluation;
2929
import org.nd4j.linalg.activations.Activation;
3030
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
3131
import org.nd4j.linalg.learning.config.Nesterovs;

dl4j-examples/src/main/java/org/deeplearning4j/examples/transferlearning/vgg16/EditLastLayerOthersFrozen.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -16,18 +16,15 @@
1616

1717
package org.deeplearning4j.examples.transferlearning.vgg16;
1818

19-
import org.deeplearning4j.eval.Evaluation;
2019
import org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers.FlowerDataSetIterator;
2120
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
2221
import org.deeplearning4j.nn.conf.layers.OutputLayer;
2322
import org.deeplearning4j.nn.graph.ComputationGraph;
24-
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
25-
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
2623
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
2724
import org.deeplearning4j.nn.transferlearning.TransferLearning;
28-
import org.deeplearning4j.nn.weights.WeightInit;
2925
import org.deeplearning4j.zoo.ZooModel;
3026
import org.deeplearning4j.zoo.model.VGG16;
27+
import org.nd4j.evaluation.classification.Evaluation;
3128
import org.nd4j.linalg.activations.Activation;
3229
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
3330
import org.nd4j.linalg.learning.config.Nesterovs;
@@ -60,7 +57,7 @@ public class EditLastLayerOthersFrozen {
6057
private static final int batchSize = 15;
6158
private static final String featureExtractionLayer = "fc2";
6259

63-
public static void main(String [] args) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
60+
public static void main(String [] args) throws IOException {
6461

6562
//Import vgg
6663
//Note that the model imported does not have an output layer (check printed summary)
@@ -86,8 +83,7 @@ public static void main(String [] args) throws UnsupportedKerasConfigurationExce
8683
.addLayer("predictions",
8784
new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
8885
.nIn(4096).nOut(numClasses)
89-
.weightInit(WeightInit.DISTRIBUTION)
90-
.dist(new NormalDistribution(0,0.2*(2.0/(4096+numClasses)))) //This weight init dist gave better results than Xavier
86+
.weightInit(new NormalDistribution(0,0.2*(2.0/(4096+numClasses)))) //This weight init dist gave better results than Xavier
9187
.activation(Activation.SOFTMAX).build(),
9288
"fc2")
9389
.build();

dl4j-examples/src/main/java/org/deeplearning4j/examples/transferlearning/vgg16/FineTuneFromBlockFour.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -16,12 +16,12 @@
1616

1717
package org.deeplearning4j.examples.transferlearning.vgg16;
1818

19-
import org.deeplearning4j.eval.Evaluation;
2019
import org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers.FlowerDataSetIterator;
2120
import org.deeplearning4j.nn.graph.ComputationGraph;
2221
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
2322
import org.deeplearning4j.nn.transferlearning.TransferLearning;
2423
import org.deeplearning4j.util.ModelSerializer;
24+
import org.nd4j.evaluation.classification.Evaluation;
2525
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
2626
import org.nd4j.linalg.learning.config.Sgd;
2727
import org.slf4j.Logger;
@@ -41,6 +41,7 @@
4141
* Finetuning like this is usually done with a low learning rate and a simple SGD optimizer
4242
* @author susaneraly on 3/6/17.
4343
*/
44+
@SuppressWarnings("DuplicatedCode")
4445
public class FineTuneFromBlockFour {
4546
private static final Logger log = org.slf4j.LoggerFactory.getLogger(FineTuneFromBlockFour.class);
4647

dl4j-examples/src/main/java/org/deeplearning4j/examples/transferlearning/vgg16/FitFromFeaturized.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -16,7 +16,6 @@
1616

1717
package org.deeplearning4j.examples.transferlearning.vgg16;
1818

19-
import org.deeplearning4j.eval.Evaluation;
2019
import org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers.FeaturizedPreSave;
2120
import org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers.FlowerDataSetIteratorFeaturized;
2221
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
@@ -27,9 +26,9 @@
2726
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
2827
import org.deeplearning4j.nn.transferlearning.TransferLearning;
2928
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
30-
import org.deeplearning4j.nn.weights.WeightInit;
3129
import org.deeplearning4j.zoo.ZooModel;
3230
import org.deeplearning4j.zoo.model.VGG16;
31+
import org.nd4j.evaluation.classification.Evaluation;
3332
import org.nd4j.linalg.activations.Activation;
3433
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
3534
import org.nd4j.linalg.learning.config.Nesterovs;
@@ -51,10 +50,11 @@
5150
* Since the helper avoids the forward pass through the frozen layers we save on computation time when running multiple epochs.
5251
* In this manner, users can iterate quickly tweaking learning rates, weight initialization etc` to settle on a model that gives good results.
5352
*/
53+
@SuppressWarnings("DuplicatedCode")
5454
public class FitFromFeaturized {
5555
private static final Logger log = org.slf4j.LoggerFactory.getLogger(FitFromFeaturized.class);
5656

57-
public static final String featureExtractionLayer = FeaturizedPreSave.featurizeExtractionLayer;
57+
private static final String featureExtractionLayer = FeaturizedPreSave.featurizeExtractionLayer;
5858
protected static final long seed = 12345;
5959
protected static final int numClasses = 5;
6060
protected static final int nEpochs = 3;
@@ -85,8 +85,7 @@ public static void main(String [] args) throws IOException, InvalidKerasConfigur
8585
.addLayer("predictions",
8686
new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
8787
.nIn(4096).nOut(numClasses)
88-
.weightInit(WeightInit.DISTRIBUTION)
89-
.dist(new NormalDistribution(0,0.2*(2.0/(4096+numClasses)))) //This weight init dist gave better results than Xavier
88+
.weightInit(new NormalDistribution(0,0.2*(2.0/(4096+numClasses)))) //This weight init dist gave better results than Xavier
9089
.activation(Activation.SOFTMAX).build(),
9190
"fc2")
9291
.build();

dl4j-examples/src/main/java/org/deeplearning4j/examples/transferlearning/vgg16/dataHelpers/FeaturizedPreSave.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -17,8 +17,6 @@
1717
package org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers;
1818

1919
import org.deeplearning4j.nn.graph.ComputationGraph;
20-
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
21-
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
2220
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
2321
import org.deeplearning4j.zoo.ZooModel;
2422
import org.deeplearning4j.zoo.model.VGG16;
@@ -42,7 +40,7 @@ public class FeaturizedPreSave {
4240
protected static final int batchSize = 15;
4341
public static final String featurizeExtractionLayer = "fc2";
4442

45-
public static void main(String [] args) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
43+
public static void main(String [] args) throws IOException {
4644

4745
//import org.deeplearning4j.transferlearning.vgg16 and print summary
4846
log.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n");
@@ -76,7 +74,7 @@ public static void main(String [] args) throws UnsupportedKerasConfigurationExce
7674
log.info("Finished pre saving featurized test and train data");
7775
}
7876

79-
public static void saveToDisk(DataSet currentFeaturized, int iterNum, boolean isTrain) {
77+
private static void saveToDisk(DataSet currentFeaturized, int iterNum, boolean isTrain) {
8078
File fileFolder = isTrain ? new File("trainFolder"): new File("testFolder");
8179
if (iterNum == 0) {
8280
fileFolder.mkdirs();

dl4j-examples/src/main/java/org/deeplearning4j/examples/transferlearning/vgg16/dataHelpers/FlowerDataSetIterator.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -69,7 +69,7 @@ public static DataSetIterator testIterator() throws IOException {
6969

7070
}
7171

72-
public static void setup(int batchSizeArg, int trainPerc) throws IOException {
72+
public static void setup(int batchSizeArg, int trainPerc) {
7373
try {
7474
downloadAndUntar();
7575
} catch (IOException e) {
@@ -96,7 +96,7 @@ private static DataSetIterator makeIterator(InputSplit split) throws IOException
9696
return iter;
9797
}
9898

99-
public static void downloadAndUntar() throws IOException {
99+
private static void downloadAndUntar() throws IOException {
100100
File rootFile = new File(DATA_DIR);
101101
if (!rootFile.exists()) {
102102
rootFile.mkdir();

dl4j-examples/src/main/java/org/deeplearning4j/examples/transferlearning/vgg16/dataHelpers/FlowerDataSetIteratorFeaturized.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
package org.deeplearning4j.examples.transferlearning.vgg16.dataHelpers;
1818

19-
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
2019
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
2120
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
21+
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
2222
import org.nd4j.linalg.dataset.ExistingMiniBatchDataSetIterator;
2323
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
2424
import org.slf4j.Logger;
@@ -33,11 +33,7 @@
3333
public class FlowerDataSetIteratorFeaturized {
3434
private static final Logger log = org.slf4j.LoggerFactory.getLogger(FlowerDataSetIteratorFeaturized.class);
3535

36-
static String featureExtractorLayer = FeaturizedPreSave.featurizeExtractionLayer;
37-
38-
public static void setup(String featureExtractorLayerArg) {
39-
featureExtractorLayer = featureExtractorLayerArg;
40-
}
36+
private static String featureExtractorLayer = FeaturizedPreSave.featurizeExtractionLayer;
4137

4238
public static DataSetIterator trainIterator() throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
4339
runFeaturize();

0 commit comments

Comments
 (0)