Skip to content

Commit 337e0d9

Browse files
author
Ryan Nett
committed
fixes
Signed-off-by: Ryan Nett <[email protected]>
1 parent f6d8703 commit 337e0d9

File tree

3 files changed

+227
-146
lines changed

3 files changed

+227
-146
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/samediff/tfimport/SameDiffTFImportMobileNetExample.java

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,50 @@
2121
*/
2222
public class SameDiffTFImportMobileNetExample {
2323

24+
public static void main(String[] args) throws Exception {
25+
26+
// download and extract a tensorflow frozen model file (usually a .pb file)
27+
File modelFile = downloadModel();
28+
29+
// import the frozen model into a SameDiff instance
30+
SameDiff sd = SameDiff.importFrozenTF(modelFile);
31+
32+
System.out.println(sd.summary());
33+
34+
System.out.println("\n\n");
35+
36+
// get the image from https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/img/image2.jpg for testing
37+
INDArray testImage = getTestImage();
38+
39+
// preprocess image with inception preprocessing
40+
INDArray preprocessedImage = inceptionPreprocessing(testImage, 224, 224);
41+
42+
// Input and output names are found by looking at sd.summary() (printed earlyer).
43+
// The input variable is the output of no ops, and the output variable is the input of no ops.
44+
45+
// Alternatively, you can use sd.outputs() and sd.inputs().
46+
47+
System.out.println("Input: " + sd.inputs());
48+
System.out.println("Output: " + sd.outputs());
49+
50+
// Do inference for a single batch.
51+
INDArray out = sd.batchOutput()
52+
.input("input", preprocessedImage)
53+
.output("MobilenetV2/Predictions/Reshape_1")
54+
.execSingle();
55+
56+
// ignore label 0 (the background label)
57+
out = out.get(NDArrayIndex.all(), NDArrayIndex.interval(1, 1001));
58+
59+
// get the readable label for the classes
60+
String label = new ImageNetLabels().decodePredictions(out);
61+
62+
System.out.println("Predictions: " + label);
63+
64+
}
65+
66+
67+
2468
public static String MODEL_URL = "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz";
2569

2670
// download and extract the model file in the ~/dl4j-examples-data directory used by other examples
@@ -50,6 +94,8 @@ public static INDArray getTestImage() throws IOException {
5094
* Does inception preprocessing. Takes an image with shape [c, h, w]
5195
* and returns an image with shape [1, height, width, c].
5296
*
97+
* Eventually this will be made part of DL4J.
98+
*
5399
* @param height the height to resize to
54100
* @param width the width to resize to
55101
*/
@@ -78,46 +124,4 @@ public static INDArray inceptionPreprocessing(INDArray img, int height, int widt
78124
return preprocessedImage;
79125
}
80126

81-
public static void main(String[] args) throws Exception {
82-
83-
// download and extract a tensorflow frozen model file (usually a .pb file)
84-
File modelFile = downloadModel();
85-
86-
// import the frozen model into a SameDiff instance
87-
SameDiff sd = SameDiff.importFrozenTF(modelFile);
88-
89-
System.out.println(sd.summary());
90-
91-
System.out.println("\n\n");
92-
93-
// get the image from https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/img/image2.jpg for testing
94-
INDArray testImage = getTestImage();
95-
96-
// preprocess image with inception preprocessing
97-
INDArray preprocessedImage = inceptionPreprocessing(testImage, 224, 224);
98-
99-
// Input and output names are found by looking at sd.summary() (printed earlyer).
100-
// The input variable is the output of no ops, and the output variable is the input of no ops.
101-
102-
// Alternatively, you can use sd.outputs() and sd.inputs().
103-
104-
System.out.println("Input: " + sd.inputs());
105-
System.out.println("Output: " + sd.outputs());
106-
107-
// Do inference for a single batch.
108-
INDArray out = sd.batchOutput()
109-
.input("input", preprocessedImage)
110-
.output("MobilenetV2/Predictions/Reshape_1")
111-
.execSingle();
112-
113-
// ignore label 0 (the background label)
114-
out = out.get(NDArrayIndex.all(), NDArrayIndex.interval(1, 1001));
115-
116-
// get the readable label for the classes
117-
String label = new ImageNetLabels().decodePredictions(out);
118-
119-
System.out.println("Predictions: " + label);
120-
121-
}
122-
123127
}

dl4j-examples/src/main/java/org/deeplearning4j/examples/samediff/tfimport/SameDiffTransferLearningExample.java

Lines changed: 134 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import org.deeplearning4j.datasets.fetchers.DataSetType;
88
import org.deeplearning4j.datasets.iterator.impl.Cifar10DataSetIterator;
99
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
10+
import org.deeplearning4j.examples.samediff.training.SameDiffCustomListenerExample;
1011
import org.deeplearning4j.examples.samediff.training.SameDiffMNISTTrainingExample;
1112
import org.nd4j.autodiff.listeners.At;
1213
import org.nd4j.autodiff.listeners.BaseListener;
@@ -42,64 +43,15 @@
4243
/**
4344
* This is an example of doing transfer learning by importing a tensorflow model of mobilenet and replacing the last layer.
4445
*
45-
* It turns the original imagenet model into a model for CIFAR 10.
46+
* It turns the original ImageNet model into a model for CIFAR 10.
4647
*
4748
* See {@link SameDiffTFImportMobileNetExample} for the model import example.
4849
* See {@link SameDiffMNISTTrainingExample} for the SameDiff training example.
50+
* See {@link SameDiffCustomListenerExample} for an example of how to use custom listeners (we use one here to find the shapes of an activation).
4951
*
5052
*/
5153
public class SameDiffTransferLearningExample {
5254

53-
// Used to figure out the shapes of variables, needed to figure out how many channels are going into our added Conv layer
54-
static class ShapeListener extends BaseListener{
55-
56-
@Override
57-
public boolean isActive(Operation operation) {
58-
return true;
59-
}
60-
61-
@Override
62-
public void activationAvailable(SameDiff sd, At at,
63-
MultiDataSet batch, SameDiffOp op,
64-
String varName, INDArray activation) {
65-
System.out.println(varName + ": \t\t\t" + Arrays.toString(activation.shape()));
66-
67-
if(varName.endsWith("Shape")){
68-
System.out.println("Shape value: " + activation);
69-
}
70-
71-
}
72-
}
73-
74-
/**
75-
* Does inception preprocessing on a batch of images. Takes an image with shape [batchSize, c, h, w]
76-
* and returns an image with shape [batchSize, height, width, c].
77-
*
78-
* @param height the height to resize to
79-
* @param width the width to resize to
80-
*/
81-
public static INDArray batchInceptionPreprocessing(INDArray img, int height, int width){
82-
// change to channels-last
83-
img = img.permute(0, 2, 3, 1);
84-
85-
// normalize to 0-1
86-
img = img.div(256);
87-
88-
// resize
89-
INDArray preprocessedImage = Nd4j.createUninitialized(img.size(0), height, width, img.size(3));
90-
91-
DynamicCustomOp op = DynamicCustomOp.builder("resize_bilinear")
92-
.addInputs(img)
93-
.addOutputs(preprocessedImage)
94-
.addIntegerArguments(height, width).build();
95-
Nd4j.exec(op);
96-
97-
// finish preprocessing
98-
preprocessedImage = preprocessedImage.sub(0.5);
99-
preprocessedImage = preprocessedImage.mul(2);
100-
return preprocessedImage;
101-
}
102-
10355
public static void main(String[] args) throws Exception {
10456
File modelFile = SameDiffTFImportMobileNetExample.downloadModel();
10557

@@ -112,7 +64,14 @@ public static void main(String[] args) throws Exception {
11264

11365
System.out.println("\n\n");
11466

115-
// Print shapes for each activation
67+
68+
// We want to replace the last convolution layer and the output layer with our own ops, so we can fine tune the network
69+
// These are the MobilenetV2/Logits and MobilenetV2/Predictions sections, respectively. See the printed summary.
70+
71+
72+
// Print shapes for each activation.
73+
// We need to know the shape (especially the channels) of the convolution op's input, so we know what shape to make the weight.
74+
// We use a custom listener for this, see SameDiffCustomListenerExample
11675

11776
// INDArray test = new Cifar10DataSetIterator(10).next().getFeatures();
11877
// test = batchInceptionPreprocessing(test, 224, 224);
@@ -123,23 +82,55 @@ public static void main(String[] args) throws Exception {
12382
// .listeners(new ShapeListener())
12483
// .execSingle();
12584

126-
// get info for the last convolution layer (MobilenetV2/Logits)
85+
// get info for the last convolution layer (MobilenetV2/Logits). We want to use an equivalent config.
12786
Conv2D convOp = (Conv2D) sd.getOpById("MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D");
12887
System.out.println("Conv config: " + convOp.getConfig());
12988

130-
// replace last convolution layer (MobilenetV2/Logits)
131-
sd = GraphTransformUtil.replaceSubgraphsMatching(sd,
89+
/*
90+
The MobilenetV2/Logits section looks like:
91+
MobilenetV2/Logits/AvgPool
92+
MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D
93+
MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd
94+
MobilenetV2/Logits/Squeeze
95+
96+
We want to replace the convolution layer (Conv2D and BiasAdd) with our own, so we can fine tune it.
97+
98+
99+
The SubGraphPredicate will select a subset of the graph by starting at the root node,
100+
and then optionally applying SubGraphPredicate's for inputs.
101+
Those SubGraphPredicate's can also add their inputs, etc.
102+
103+
The predicate will only accept a subgraph if it passes all the filters.
104+
*/
105+
106+
// Create a predicate for selecting the BiasAdd and Conv2D ops we want
107+
SubGraphPredicate pred1 =
108+
// Select the subgraph with root MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd
132109
SubGraphPredicate.withRoot(OpPredicate.nameMatches("MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd"))
133-
.withInputSubgraph(0, OpPredicate.nameMatches("MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D")),
134-
(sd1, subGraph) -> {
110+
// Select (and require) the BiasAdd's 0th input to be MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D
111+
.withInputSubgraph(0, OpPredicate.nameMatches("MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D"));
112+
113+
114+
/*
115+
Replace any subgraphs matching the predicate with our own subgraph
116+
There will only be one match, but you can use SubGraphPredicate and GraphTransformUtil to replace many occurrences of the same subgraph.
135117
118+
The number of outputs from the replacement subgraph must match the number of outputs of the subgraph it is replacing.
119+
120+
Note that the graph isn't actually modified, a copy is made, modified, and then returned.
121+
*/
122+
sd = GraphTransformUtil.replaceSubgraphsMatching(sd,
123+
pred1,
124+
(sd1, subGraph) -> {
136125
NameScope logits = sd1.withNameScope("Logits/Conv2D");
137126

138127
// get the output of the AveragePooling op
139128
SDVariable input = subGraph.inputs().get(1);
140129

141130
// we know the sizes from using the ShapeListener earlier
142131

132+
// We know what shape the weight needs to be from the input's channels and the config's kernel height and width.
133+
// This is why we printed the shapes.
143134
SDVariable w = sd1.var("W", new XavierInitScheme('c', 5 * 5 * 8, 10), DataType.FLOAT,
144135
1, 1, 1280, 10);
145136

@@ -155,17 +146,38 @@ public static void main(String[] args) throws Exception {
155146
return Collections.singletonList(output);
156147
});
157148

158-
// create SubGraphPredicate for selecting the MobilenetV2/Predictions ops
159-
SubGraphPredicate graphPred = SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Reshape_1"))
160-
.withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Softmax"))
161-
.withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Reshape"))))
162-
.withInputSubgraph(1, SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Shape")));
149+
/*
150+
The MobilenetV2/Predictions section looks like:
151+
MobilenetV2/Predictions/Reshape/shape
152+
MobilenetV2/Predictions/Reshape
153+
MobilenetV2/Predictions/Softmax
154+
MobilenetV2/Predictions/Shape
155+
MobilenetV2/Predictions/Reshape_1
163156
164-
// replace the MobilenetV2/Predictions with our own softmax and loss
157+
We want to replace the reshapes (unneeded and the wrong shape) and the softmax (we need a loss function and an output function).
158+
You could keep the softmax, but there is no reason to.
159+
160+
We also need to add a labels input.
161+
162+
Note that this subgraph has no outputs, so neither should the replacement subgraph.
163+
*/
164+
165+
// create SubGraphPredicate for selecting the MobilenetV2/Predictions ops
166+
SubGraphPredicate pred2 =
167+
// Select a subgraph starting with the Reshape_1 op
168+
SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Reshape_1"))
169+
// Add the 0th input to the subgraph if it is the specified Softmax Op
170+
.withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Softmax"))
171+
// Add the 0th input of the Softmax op to the subgraph, as long as it is the specified Reshape op
172+
.withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Reshape"))))
173+
// Add the 1st input to the subgraph if it is the specified Shape Op
174+
.withInputSubgraph(1, SubGraphPredicate.withRoot(OpPredicate.nameEquals("MobilenetV2/Predictions/Shape")));
175+
176+
// Replace any subgraphs matching the predicate with our own subgraph
177+
// There will only be one match, but you can use SubGraphPredicate and GraphTransformUtil to replace many occurrences of the same subgraph
165178
sd = GraphTransformUtil.replaceSubgraphsMatching(sd,
166-
graphPred,
179+
pred2,
167180
(sd1, subGraph) -> {
168-
169181
// placeholder for labels (needed for training)
170182
SDVariable labels = sd1.placeHolder("label", DataType.FLOAT, -1, 10);
171183

@@ -186,8 +198,8 @@ public static void main(String[] args) throws Exception {
186198
});
187199

188200

189-
// replace the input with input and inception preprocessing (except for resizing, which is done as part of the record reader)
190-
// can't do this with GraphTransformUtil as it can't replace variables or re-use ops
201+
// Add inception preprocessing to the input (except for resizing, which is done as part of the record reader)
202+
// Can't do this with GraphTransformUtil as it can't replace variables or re-use ops
191203

192204
SDVariable input = sd.getVariable("input");
193205

@@ -200,6 +212,7 @@ public static void main(String[] args) throws Exception {
200212
// change range to -1 - 1
201213
SDVariable processed = normalized.sub(0.5).mul(2);
202214

215+
// The 0th arg was input, replace it with the preprocessed input
203216
sd.getOpById("MobilenetV2/Conv/Conv2D").replaceArg(0, processed);
204217

205218

@@ -254,4 +267,58 @@ public static void main(String[] args) throws Exception {
254267

255268
System.out.println("Accuracy: " + acc);
256269
}
270+
271+
/**
272+
* Used to figure out the shapes of variables, needed to figure out how many channels are going into our added Conv layer
273+
*
274+
* See {@link SameDiffCustomListenerExample}
275+
*/
276+
static class ShapeListener extends BaseListener{
277+
278+
@Override
279+
public boolean isActive(Operation operation) {
280+
return true;
281+
}
282+
283+
@Override
284+
public void activationAvailable(SameDiff sd, At at,
285+
MultiDataSet batch, SameDiffOp op,
286+
String varName, INDArray activation) {
287+
System.out.println(varName + ": \t\t\t" + Arrays.toString(activation.shape()));
288+
289+
if(varName.endsWith("Shape")){
290+
System.out.println("Shape value: " + activation);
291+
}
292+
293+
}
294+
}
295+
296+
/**
297+
* Does inception preprocessing on a batch of images. Takes an image with shape [batchSize, c, h, w]
298+
* and returns an image with shape [batchSize, height, width, c].
299+
*
300+
* @param height the height to resize to
301+
* @param width the width to resize to
302+
*/
303+
public static INDArray batchInceptionPreprocessing(INDArray img, int height, int width){
304+
// change to channels-last
305+
img = img.permute(0, 2, 3, 1);
306+
307+
// normalize to 0-1
308+
img = img.div(256);
309+
310+
// resize
311+
INDArray preprocessedImage = Nd4j.createUninitialized(img.size(0), height, width, img.size(3));
312+
313+
DynamicCustomOp op = DynamicCustomOp.builder("resize_bilinear")
314+
.addInputs(img)
315+
.addOutputs(preprocessedImage)
316+
.addIntegerArguments(height, width).build();
317+
Nd4j.exec(op);
318+
319+
// finish preprocessing
320+
preprocessedImage = preprocessedImage.sub(0.5);
321+
preprocessedImage = preprocessedImage.mul(2);
322+
return preprocessedImage;
323+
}
257324
}

0 commit comments

Comments
 (0)