
Commit 7bda08a

Author: Ryan Nett (committed)

Add transfer learning and tf import examples

Signed-off-by: Ryan Nett <[email protected]>
1 parent 7823f81 commit 7bda08a

File tree: 2 files changed, +379 -0 lines changed

@@ -0,0 +1,123 @@
package org.deeplearning4j.examples.samediff.tfimport;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import org.apache.commons.io.FilenameUtils;
import org.datavec.image.loader.ImageLoader;
import org.deeplearning4j.zoo.model.helper.InceptionResNetHelper;
import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.resources.Downloader;

/**
 * This example shows how to import TensorFlow models, specifically MobileNet, and use them for inference.
 */
public class SameDiffTFImportMobileNetExample {

    public static String MODEL_URL = "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz";

    // download and extract the model file in the ~/dl4j-examples-data directory used by other examples
    public static File downloadModel() throws Exception {
        String dataDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/tf_resnet");
        String modelFile = FilenameUtils.concat(dataDir, "mobilenet_v2_1.0_224.tgz");

        File frozenFile = new File(FilenameUtils.concat(dataDir, "mobilenet_v2_1.0_224_frozen.pb"));

        if (frozenFile.exists()) {
            return frozenFile;
        }

        Downloader.downloadAndExtract("tf_resnet", new URL(MODEL_URL), new File(modelFile), new File(dataDir), "519bba7052fd279c66d2a28dc3f51f46", 5);

        return frozenFile;
    }
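
    // Note on the downloadAndExtract call above: judging from its use here, the trailing arguments are the
    // archive's expected MD5 checksum and a retry count. That reading is an assumption; check the Downloader javadoc.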

    // Gets the image we use to test the network.
    // This isn't a single-class ImageNet image, so the model won't do very well, but it will at least classify it as a dog or a cat.
    public static INDArray getTestImage() throws IOException {
        URL url = new URL("https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/img/image2.jpg?raw=true");
        return new ImageLoader(358, 500, 3).asMatrix(url.openStream());
    }
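
    // The loaded image comes back channels-first with shape [c, h, w] (here presumably [3, 358, 500]),
    // which is why inceptionPreprocessing below adds a batch dimension and permutes to NHWC.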

    /**
     * Does inception preprocessing. Takes an image with shape [c, h, w]
     * and returns an image with shape [1, height, width, c].
     *
     * @param height the height to resize to
     * @param width the width to resize to
     */
    public static INDArray inceptionPreprocessing(INDArray img, int height, int width) {
        // add batch dimension
        img = Nd4j.expandDims(img, 0);

        // change to channels-last
        img = img.permute(0, 2, 3, 1);

        // normalize to 0-1
        img = img.div(256);

        // resize
        INDArray preprocessedImage = Nd4j.createUninitialized(1, height, width, 3);

        DynamicCustomOp op = DynamicCustomOp.builder("resize_bilinear")
            .addInputs(img)
            .addOutputs(preprocessedImage)
            .addIntegerArguments(height, width).build();
        Nd4j.exec(op);

        // finish preprocessing: shift from [0, 1] to [-1, 1]
        preprocessedImage = preprocessedImage.sub(0.5);
        preprocessedImage = preprocessedImage.mul(2);
        return preprocessedImage;
    }
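
    // A minimal sanity check (assuming the test image is reachable):
    // INDArray preprocessed = inceptionPreprocessing(getTestImage(), 224, 224);
    // System.out.println(Arrays.toString(preprocessed.shape()));   // expect [1, 224, 224, 3]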

    public static void main(String[] args) throws Exception {

        // download and extract a TensorFlow frozen model file (usually a .pb file)
        File modelFile = downloadModel();

        // import the frozen model into a SameDiff instance
        SameDiff sd = SameDiff.importFrozenTF(modelFile);

        System.out.println(sd.summary());

        System.out.println("\n\n");

        // get the image from https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/img/image2.jpg for testing
        INDArray testImage = getTestImage();

        // preprocess image with inception preprocessing
        INDArray preprocessedImage = inceptionPreprocessing(testImage, 224, 224);

        // Input and output names are found by looking at sd.summary() (printed earlier).
        // The input variable is the output of no ops, and the output variable is the input of no ops.

        // Alternatively, you can use sd.outputs() and sd.inputs().

        System.out.println("Input: " + sd.inputs());
        System.out.println("Output: " + sd.outputs());

        // Do inference for a single batch.
        INDArray out = sd.batchOutput()
            .input("input", preprocessedImage)
            .output("MobilenetV2/Predictions/Reshape_1")
            .execSingle();
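
        // out has shape [1, 1001]: TF's MobileNet puts a "background" class at index 0
        // ahead of the 1000 ImageNet classes, hence the slice below.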

        // ignore label 0 (the background label)
        out = out.get(NDArrayIndex.all(), NDArrayIndex.interval(1, 1001));

        // get the readable label for the classes
        String label = new ImageNetLabels().decodePredictions(out);

        System.out.println("Predictions: " + label);
    }
}
@@ -0,0 +1,256 @@
package org.deeplearning4j.examples.samediff.tfimport;

import java.io.File;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.deeplearning4j.datasets.fetchers.DataSetType;
import org.deeplearning4j.datasets.iterator.impl.Cifar10DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.examples.samediff.training.SameDiffMNISTTrainingExample;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.NameScope;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
import org.nd4j.autodiff.samediff.transform.OpPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraph;
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.weightinit.impl.XavierInitScheme;

/**
 * This is an example of transfer learning: it imports a TensorFlow MobileNet model and replaces the last layer.
 *
 * It turns the original ImageNet model into a model for CIFAR-10.
 *
 * See {@link SameDiffTFImportMobileNetExample} for the model import example.
 * See {@link SameDiffMNISTTrainingExample} for the SameDiff training example.
 */
public class SameDiffTransferLearningExample {

    // Used to inspect the shapes of activations, so we can tell how many channels feed into our added Conv layer
    static class ShapeListener extends BaseListener {

        @Override
        public boolean isActive(Operation operation) {
            return true;
        }

        @Override
        public void activationAvailable(SameDiff sd, At at,
                                        MultiDataSet batch, SameDiffOp op,
                                        String varName, INDArray activation) {
            System.out.println(varName + ": \t\t\t" + Arrays.toString(activation.shape()));

            if (varName.endsWith("Shape")) {
                System.out.println("Shape value: " + activation);
            }
        }
    }
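
    // activationAvailable fires once per op output during execution, so a single forward pass with this
    // listener attached prints the shape of every intermediate activation (see the commented-out block in main).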

    /**
     * Does inception preprocessing on a batch of images. Takes images with shape [batchSize, c, h, w]
     * and returns images with shape [batchSize, height, width, c].
     *
     * @param height the height to resize to
     * @param width the width to resize to
     */
    public static INDArray batchInceptionPreprocessing(INDArray img, int height, int width) {
        // change to channels-last
        img = img.permute(0, 2, 3, 1);

        // normalize to 0-1
        img = img.div(256);

        // resize
        INDArray preprocessedImage = Nd4j.createUninitialized(img.size(0), height, width, img.size(3));

        DynamicCustomOp op = DynamicCustomOp.builder("resize_bilinear")
            .addInputs(img)
            .addOutputs(preprocessedImage)
            .addIntegerArguments(height, width).build();
        Nd4j.exec(op);

        // finish preprocessing: shift from [0, 1] to [-1, 1]
        preprocessedImage = preprocessedImage.sub(0.5);
        preprocessedImage = preprocessedImage.mul(2);
        return preprocessedImage;
    }

    public static void main(String[] args) throws Exception {
        File modelFile = SameDiffTFImportMobileNetExample.downloadModel();

        // import the frozen model into a SameDiff instance
        SameDiff sd = SameDiff.importFrozenTF(modelFile);

        System.out.println("\n\n------------------- Initial Graph -------------------");
        System.out.println(sd.summary());
        System.out.println("\n\n");

        // Print shapes for each activation

        // INDArray test = new Cifar10DataSetIterator(10).next().getFeatures();
        // test = batchInceptionPreprocessing(test, 224, 224);
        //
        // sd.batchOutput()
        //     .input("input", test)
        //     .output("MobilenetV2/Predictions/Reshape_1")
        //     .listeners(new ShapeListener())
        //     .execSingle();

        // get info for the last convolution layer (MobilenetV2/Logits)
        Conv2D convOp = (Conv2D) sd.getOpById("MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D");
        System.out.println("Conv config: " + convOp.getConfig());

        // replace the last convolution layer (MobilenetV2/Logits)
        sd = GraphTransformUtil.replaceSubgraphsMatching(sd,
            SubGraphPredicate.withRoot(OpPredicate.nameMatches("MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd"))
                .withInputSubgraph(0, OpPredicate.nameMatches("MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D")),
            (sd1, subGraph) -> {

                NameScope logits = sd1.withNameScope("Logits/Conv2D");

                // get the output of the AveragePooling op
                SDVariable input = subGraph.inputs().get(1);

                // we know the sizes from using the ShapeListener earlier

                SDVariable w = sd1.var("W", new XavierInitScheme('c', 5 * 5 * 8, 10), DataType.FLOAT,
                    1, 1, 1280, 10);

                SDVariable b = sd1.var("b", new XavierInitScheme('c', 10 * 1280, 10 * 10), DataType.FLOAT,
                    10);

                // We know the needed config by getting and printing the convolution config earlier
                SDVariable output = sd1.cnn().conv2d(input, w, b, Conv2DConfig.builder()
                    .kH(1).kW(1).isSameMode(true).dataFormat("NHWC").build());

                logits.close();

                return Collections.singletonList(output);
            });
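
        // The new W ([1, 1, 1280, 10]) and b ([10]) are fresh SameDiff variables and thus trainable;
        // the imported MobileNet weights were loaded as constants, so only this new layer is fit below.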

        // create a SubGraphPredicate for selecting the MobilenetV2/Predictions ops
        SubGraphPredicate graphPred = SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Reshape_1"))
            .withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Softmax"))
                .withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Reshape"))))
            .withInputSubgraph(1, SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Shape")));

        // replace the MobilenetV2/Predictions ops with our own softmax and loss
        sd = GraphTransformUtil.replaceSubgraphsMatching(sd,
            graphPred,
            (sd1, subGraph) -> {

                // placeholder for labels (needed for training)
                SDVariable labels = sd1.placeHolder("label", DataType.FLOAT, -1, 10);

                NameScope predictions = sd1.withNameScope("Predictions");

                // get the output of the preceding squeeze op
                SDVariable input = subGraph.inputs().get(0);

                // softmax over dimension 1 (the default)
                SDVariable outputs = sd1.nn().softmax("Output", input);

                // we need a loss to train on; the TensorFlow model doesn't come with one
                SDVariable loss = sd1.loss().softmaxCrossEntropy("Loss", labels, input);

                predictions.close();

                return Collections.emptyList();
            });
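
        // Returning an empty list works here because the replaced prediction ops were the graph outputs:
        // nothing downstream consumed them, and the new Output and Loss variables were created directly on sd1.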

        // replace the input with input plus inception preprocessing (except for resizing, which is done as part of the record reader)
        // we can't do this with GraphTransformUtil, as it can't replace variables or re-use ops

        SDVariable input = sd.getVariable("input");

        // change input to channels-last (because this is a TensorFlow import)
        SDVariable channelsLast = input.permute(0, 2, 3, 1);

        // normalize to 0-1
        SDVariable normalized = channelsLast.div(256);

        // change range to [-1, 1]
        SDVariable processed = normalized.sub(0.5).mul(2);

        sd.getOpById("MobilenetV2/Conv/Conv2D").replaceArg(0, processed);
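
        // After replaceArg, the graph's first convolution consumes `processed` instead of the raw `input`
        // placeholder. The resize step is skipped because the CIFAR-10 iterator below already emits 224x224 images.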

        System.out.println("\n\n------------------- Final Graph -------------------");
        System.out.println(sd.summary());

        SDVariable output = sd.getVariable("Predictions/Output");
        SDVariable loss = sd.getVariable("Predictions/Loss");

        // we resize to the proper size as part of the data set iterator, rather than doing it as part of the inception preprocessing
        INDArray test2 = new Cifar10DataSetIterator(10, new int[]{224, 224}, DataSetType.TRAIN, null, 12345).next().getFeatures();
        System.out.println("CIFAR10 Shape: " + Arrays.toString(test2.shape()));

        // Test run
        sd.batchOutput()
            .input("input", test2)
            .output(output)
            // .listeners(new ShapeListener())
            .execSingle();

        // need to set the loss variable for training
        sd.setLossVariables(loss);

        // the TensorFlow model doesn't come with placeholder shapes, but we need to set them for training
        sd.getVariable("input").setShape(new long[]{-1, 3, 224, 224});
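
        // -1 marks the batch dimension as variable; the rest is NCHW, matching what the CIFAR-10 iterator
        // emits (the permute added above converts it to NHWC inside the graph).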

        // Training. See SameDiffMNISTTrainingExample for more details
        double learningRate = 1e-3;
        TrainingConfig config = new TrainingConfig.Builder()
            .l2(1e-4)                                     // L2 regularization
            .updater(new Adam(learningRate))              // Adam optimizer with the specified learning rate
            .dataSetFeatureMapping("input")               // DataSet features array should be associated with variable "input"
            .dataSetLabelMapping("label")                 // DataSet label array should be associated with variable "label"
            .trainEvaluation(output, 0, new Evaluation()) // add a training evaluation
            .build();

        sd.setTrainingConfig(config);
        sd.addListeners(new ScoreListener(20));

        // again, we resize to the proper size as part of the data set iterator
        DataSetIterator trainData = new Cifar10DataSetIterator(32, new int[]{224, 224}, DataSetType.TRAIN, null, 12345);

        // Perform fine-tuning for 20 epochs. The pre-trained weights were imported as constants, and thus are not trained
        int numEpochs = 20;
        History hist = sd.fit()
            .train(trainData, numEpochs)
            .exec();
        List<Double> acc = hist.trainingEval(Metric.ACCURACY);

        System.out.println("Accuracy: " + acc);
    }
}
