7
7
import org .deeplearning4j .datasets .fetchers .DataSetType ;
8
8
import org .deeplearning4j .datasets .iterator .impl .Cifar10DataSetIterator ;
9
9
import org .deeplearning4j .datasets .iterator .impl .MnistDataSetIterator ;
10
+ import org .deeplearning4j .examples .samediff .training .SameDiffCustomListenerExample ;
10
11
import org .deeplearning4j .examples .samediff .training .SameDiffMNISTTrainingExample ;
11
12
import org .nd4j .autodiff .listeners .At ;
12
13
import org .nd4j .autodiff .listeners .BaseListener ;
42
43
/**
43
44
* This is an example of doing transfer learning by importing a tensorflow model of mobilenet and replacing the last layer.
44
45
*
45
- * It turns the original imagenet model into a model for CIFAR 10.
46
+ * It turns the original ImageNet model into a model for CIFAR 10.
46
47
*
47
48
* See {@link SameDiffTFImportMobileNetExample} for the model import example.
48
49
* See {@link SameDiffMNISTTrainingExample} for the SameDiff training example.
50
+ * See {@link SameDiffCustomListenerExample} for an example of how to use custom listeners (we use one here to find the shapes of an activation).
49
51
*
50
52
*/
51
53
public class SameDiffTransferLearningExample {
52
54
53
- // Used to figure out the shapes of variables, needed to figure out how many channels are going into our added Conv layer
54
- static class ShapeListener extends BaseListener {
55
-
56
- @ Override
57
- public boolean isActive (Operation operation ) {
58
- return true ;
59
- }
60
-
61
- @ Override
62
- public void activationAvailable (SameDiff sd , At at ,
63
- MultiDataSet batch , SameDiffOp op ,
64
- String varName , INDArray activation ) {
65
- System .out .println (varName + ": \t \t \t " + Arrays .toString (activation .shape ()));
66
-
67
- if (varName .endsWith ("Shape" )){
68
- System .out .println ("Shape value: " + activation );
69
- }
70
-
71
- }
72
- }
73
-
74
- /**
75
- * Does inception preprocessing on a batch of images. Takes an image with shape [batchSize, c, h, w]
76
- * and returns an image with shape [batchSize, height, width, c].
77
- *
78
- * @param height the height to resize to
79
- * @param width the width to resize to
80
- */
81
- public static INDArray batchInceptionPreprocessing (INDArray img , int height , int width ){
82
- // change to channels-last
83
- img = img .permute (0 , 2 , 3 , 1 );
84
-
85
- // normalize to 0-1
86
- img = img .div (256 );
87
-
88
- // resize
89
- INDArray preprocessedImage = Nd4j .createUninitialized (img .size (0 ), height , width , img .size (3 ));
90
-
91
- DynamicCustomOp op = DynamicCustomOp .builder ("resize_bilinear" )
92
- .addInputs (img )
93
- .addOutputs (preprocessedImage )
94
- .addIntegerArguments (height , width ).build ();
95
- Nd4j .exec (op );
96
-
97
- // finish preprocessing
98
- preprocessedImage = preprocessedImage .sub (0.5 );
99
- preprocessedImage = preprocessedImage .mul (2 );
100
- return preprocessedImage ;
101
- }
102
-
103
55
public static void main (String [] args ) throws Exception {
104
56
File modelFile = SameDiffTFImportMobileNetExample .downloadModel ();
105
57
@@ -112,7 +64,14 @@ public static void main(String[] args) throws Exception {
112
64
113
65
System .out .println ("\n \n " );
114
66
115
- // Print shapes for each activation
67
+
68
+ // We want to replace the last convolution layer and the output layer with our own ops, so we can fine tune the network
69
+ // These are the MobilenetV2/Logits and MobilenetV2/Predictions sections, respectively. See the printed summary.
70
+
71
+
72
+ // Print shapes for each activation.
73
+ // We need to know the shape (especially the channels) of the convolution op's input, so we know what shape to make the weight.
74
+ // We use a custom listener for this, see SameDiffCustomListenerExample
116
75
117
76
// INDArray test = new Cifar10DataSetIterator(10).next().getFeatures();
118
77
// test = batchInceptionPreprocessing(test, 224, 224);
@@ -123,23 +82,55 @@ public static void main(String[] args) throws Exception {
123
82
// .listeners(new ShapeListener())
124
83
// .execSingle();
125
84
126
- // get info for the last convolution layer (MobilenetV2/Logits)
85
+ // get info for the last convolution layer (MobilenetV2/Logits). We want to use an equivalent config.
127
86
Conv2D convOp = (Conv2D ) sd .getOpById ("MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D" );
128
87
System .out .println ("Conv config: " + convOp .getConfig ());
129
88
130
- // replace last convolution layer (MobilenetV2/Logits)
131
- sd = GraphTransformUtil .replaceSubgraphsMatching (sd ,
89
+ /*
90
+ The MobilenetV2/Logits section looks like:
91
+ MobilenetV2/Logits/AvgPool
92
+ MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D
93
+ MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd
94
+ MobilenetV2/Logits/Squeeze
95
+
96
+ We want to replace the convolution layer (Conv2D and BiasAdd) with our own, so we can fine tune it.
97
+
98
+
99
+ The SubGraphPredicate will select a subset of the graph by starting at the root node,
100
+ and then optionally applying SubGraphPredicate's for inputs.
101
+ Those SubGraphPredicate's can also add their inputs, etc.
102
+
103
+ The predicate will only accept a subgraph if it passes all the filters.
104
+ */
105
+
106
+ // Create a predicate for selecting the BiasAdd and Conv2D ops we want
107
+ SubGraphPredicate pred1 =
108
+ // Select the subgraph with root MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd
132
109
SubGraphPredicate .withRoot (OpPredicate .nameMatches ("MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd" ))
133
- .withInputSubgraph (0 , OpPredicate .nameMatches ("MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D" )),
134
- (sd1 , subGraph ) -> {
110
+ // Select (and require) the BiasAdd's 0th input to be MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D
111
+ .withInputSubgraph (0 , OpPredicate .nameMatches ("MobilenetV2/Logits/Conv2d_1c_1x1/Conv2D" ));
112
+
113
+
114
+ /*
115
+ Replace any subgraphs matching the predicate with our own subgraph
116
+ There will only be one match, but you can use SubGraphPredicate and GraphTransformUtil to replace many occurrences of the same subgraph.
135
117
118
+ The number of outputs from the replacement subgraph must match the number of outputs of the subgraph it is replacing.
119
+
120
+ Note that the graph isn't actually modified, a copy is made, modified, and then returned.
121
+ */
122
+ sd = GraphTransformUtil .replaceSubgraphsMatching (sd ,
123
+ pred1 ,
124
+ (sd1 , subGraph ) -> {
136
125
NameScope logits = sd1 .withNameScope ("Logits/Conv2D" );
137
126
138
127
// get the output of the AveragePooling op
139
128
SDVariable input = subGraph .inputs ().get (1 );
140
129
141
130
// we know the sizes from using the ShapeListener earlier
142
131
132
+ // We know what shape the weight needs to be from the input's channels and the config's kernel height and width.
133
+ // This is why we printed the shapes.
143
134
SDVariable w = sd1 .var ("W" , new XavierInitScheme ('c' , 5 * 5 * 8 , 10 ), DataType .FLOAT ,
144
135
1 , 1 , 1280 , 10 );
145
136
@@ -155,17 +146,38 @@ public static void main(String[] args) throws Exception {
155
146
return Collections .singletonList (output );
156
147
});
157
148
158
- // create SubGraphPredicate for selecting the MobilenetV2/Predictions ops
159
- SubGraphPredicate graphPred = SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Reshape_1" ))
160
- .withInputSubgraph (0 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Softmax" ))
161
- .withInputSubgraph (0 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Reshape" ))))
162
- .withInputSubgraph (1 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Shape" )));
149
+ /*
150
+ The MobilenetV2/Predictions section looks like:
151
+ MobilenetV2/Predictions/Reshape/shape
152
+ MobilenetV2/Predictions/Reshape
153
+ MobilenetV2/Predictions/Softmax
154
+ MobilenetV2/Predictions/Shape
155
+ MobilenetV2/Predictions/Reshape_1
163
156
164
- // replace the MobilenetV2/Predictions with our own softmax and loss
157
+ We want to replace the reshapes (unneeded and the wrong shape) and the softmax (we need a loss function and an output function).
158
+ You could keep the softmax, but there is no reason to.
159
+
160
+ We also need to add a labels input.
161
+
162
+ Note that this subgraph has no outputs, so neither should the replacement subgraph.
163
+ */
164
+
165
+ // create SubGraphPredicate for selecting the MobilenetV2/Predictions ops
166
+ SubGraphPredicate pred2 =
167
+ // Select a subgraph starting with the Reshape_1 op
168
+ SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Reshape_1" ))
169
+ // Add the 0th input to the subgraph if it is the specified Softmax Op
170
+ .withInputSubgraph (0 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Softmax" ))
171
+ // Add the 0th input of the Softmax op to the subgraph, as long as it is the specified Reshape op
172
+ .withInputSubgraph (0 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Reshape" ))))
173
+ // Add the 1st input to the subgraph if it is the specified Shape Op
174
+ .withInputSubgraph (1 , SubGraphPredicate .withRoot (OpPredicate .nameEquals ("MobilenetV2/Predictions/Shape" )));
175
+
176
+ // Replace any subgraphs matching the predicate with our own subgraph
177
+ // There will only be one match, but you can use SubGraphPredicate and GraphTransformUtil to replace many occurrences of the same subgraph
165
178
sd = GraphTransformUtil .replaceSubgraphsMatching (sd ,
166
- graphPred ,
179
+ pred2 ,
167
180
(sd1 , subGraph ) -> {
168
-
169
181
// placeholder for labels (needed for training)
170
182
SDVariable labels = sd1 .placeHolder ("label" , DataType .FLOAT , -1 , 10 );
171
183
@@ -186,8 +198,8 @@ public static void main(String[] args) throws Exception {
186
198
});
187
199
188
200
189
- // replace the input with input and inception preprocessing (except for resizing, which is done as part of the record reader)
190
- // can 't do this with GraphTransformUtil as it can't replace variables or re-use ops
201
+ // Add inception preprocessing to the input (except for resizing, which is done as part of the record reader)
202
+ // Can 't do this with GraphTransformUtil as it can't replace variables or re-use ops
191
203
192
204
SDVariable input = sd .getVariable ("input" );
193
205
@@ -200,6 +212,7 @@ public static void main(String[] args) throws Exception {
200
212
// change range to -1 - 1
201
213
SDVariable processed = normalized .sub (0.5 ).mul (2 );
202
214
215
+ // The 0th arg was input, replace it with the preprocessed input
203
216
sd .getOpById ("MobilenetV2/Conv/Conv2D" ).replaceArg (0 , processed );
204
217
205
218
@@ -254,4 +267,58 @@ public static void main(String[] args) throws Exception {
254
267
255
268
System .out .println ("Accuracy: " + acc );
256
269
}
270
+
271
+ /**
272
+ * Used to figure out the shapes of variables, needed to figure out how many channels are going into our added Conv layer
273
+ *
274
+ * See {@link SameDiffCustomListenerExample}
275
+ */
276
+ static class ShapeListener extends BaseListener {
277
+
278
+ @ Override
279
+ public boolean isActive (Operation operation ) {
280
+ return true ;
281
+ }
282
+
283
+ @ Override
284
+ public void activationAvailable (SameDiff sd , At at ,
285
+ MultiDataSet batch , SameDiffOp op ,
286
+ String varName , INDArray activation ) {
287
+ System .out .println (varName + ": \t \t \t " + Arrays .toString (activation .shape ()));
288
+
289
+ if (varName .endsWith ("Shape" )){
290
+ System .out .println ("Shape value: " + activation );
291
+ }
292
+
293
+ }
294
+ }
295
+
296
+ /**
297
+ * Does inception preprocessing on a batch of images. Takes an image with shape [batchSize, c, h, w]
298
+ * and returns an image with shape [batchSize, height, width, c].
299
+ *
300
+ * @param height the height to resize to
301
+ * @param width the width to resize to
302
+ */
303
+ public static INDArray batchInceptionPreprocessing (INDArray img , int height , int width ){
304
+ // change to channels-last
305
+ img = img .permute (0 , 2 , 3 , 1 );
306
+
307
+ // normalize to 0-1
308
+ img = img .div (256 );
309
+
310
+ // resize
311
+ INDArray preprocessedImage = Nd4j .createUninitialized (img .size (0 ), height , width , img .size (3 ));
312
+
313
+ DynamicCustomOp op = DynamicCustomOp .builder ("resize_bilinear" )
314
+ .addInputs (img )
315
+ .addOutputs (preprocessedImage )
316
+ .addIntegerArguments (height , width ).build ();
317
+ Nd4j .exec (op );
318
+
319
+ // finish preprocessing
320
+ preprocessedImage = preprocessedImage .sub (0.5 );
321
+ preprocessedImage = preprocessedImage .mul (2 );
322
+ return preprocessedImage ;
323
+ }
257
324
}
0 commit comments