Skip to content

Commit 0e788ea

Browse files
committed
Fixed DenseNet example
Signed-off-by: Aleksandar <[email protected]>
1 parent 9e971fd commit 0e788ea

File tree

3 files changed

+118
-74
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/densenet/DenseNetMain.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public class DenseNetMain {
6666
private static final int height = 227;
6767
private static final int width = 227;
6868
private static final int channels = 3;
69-
private static final int batchSize = 10;
69+
private static final int batchSize = 32;
7070
private static final int outputNum = 4;
7171
private static final int numEpochs = 1000;
7272
private static final double splitTrainTest = 0.8;

dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/densenet/model/DenseNetBuilder.java

Lines changed: 99 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -62,104 +62,141 @@ public ComputationGraph getModel() {
6262
return model;
6363
}
6464

65+
public int getGrowthRate() {
66+
return growthRate;
67+
}
68+
6569
public String initLayer(int kernel, int stride, int padding, int channels) {
66-
String init = "initConv";
67-
String initPool = "initPool";
68-
conf.addLayer(init, new ConvolutionLayer.Builder()
70+
ConvolutionLayer convolutionLayer = new ConvolutionLayer.Builder()
71+
.name("initConv")
6972
.kernelSize(kernel, kernel)
7073
.stride(stride, stride)
7174
.padding(padding, padding)
7275
.nIn(channels)
7376
.nOut(growthRate * 2)
74-
.build(), "input");
75-
conf.addLayer(initPool, new Pooling2D.Builder(SubsamplingLayer.PoolingType.MAX)
76-
.kernelSize(2, 2)
77+
.build();
78+
SubsamplingLayer subsamplingLayer = new Pooling2D.Builder(SubsamplingLayer.PoolingType.MAX)
79+
.name("initPool")
80+
.kernelSize(3, 3)
7781
.padding(0, 0)
78-
.build(), init);
79-
return initPool;
80-
}
82+
.build();
83+
84+
conf.addLayer(convolutionLayer.getLayerName(), convolutionLayer, "input");
85+
conf.addLayer(subsamplingLayer.getLayerName(), subsamplingLayer, convolutionLayer.getLayerName());
8186

82-
public String addTransitionLayer(String transitionName, int numIn, String... previousBlock) {
83-
String bnName = "bn_" + transitionName;
84-
String convName = "conv_" + transitionName;
85-
String poolName = "pool_" + transitionName;
87+
return subsamplingLayer.getLayerName();
88+
}
8689

87-
conf.addLayer(bnName, new BatchNormalization.Builder()
88-
.build(), previousBlock);
89-
conf.addLayer(convName, new ConvolutionLayer.Builder()
90+
public String addTransitionLayer(String name, long numIn, List<String> previousLayers) {
91+
BatchNormalization bnLayer = new BatchNormalization.Builder()
92+
.name(String.format("%s_%s", name, "bn"))
93+
.build();
94+
ConvolutionLayer layer1x1 = new ConvolutionLayer.Builder()
95+
.name(String.format("%s_%s", name, "conv"))
9096
.kernelSize(1, 1)
9197
.stride(1, 1)
9298
.padding(0, 0)
9399
.nOut(numIn / 2)
94-
.build(), bnName);
95-
conf.addLayer(poolName, new Pooling2D.Builder(SubsamplingLayer.PoolingType.AVG)
100+
.build();
101+
SubsamplingLayer subsamplingLayer = new Pooling2D.Builder(SubsamplingLayer.PoolingType.AVG)
102+
.name(String.format("%s_%s", name, "pool"))
96103
.kernelSize(2, 2)
97104
.padding(0, 0)
98-
.build(), convName);
99-
;
105+
.build();
100106

101-
return poolName;
102-
}
103-
104-
private String[] addDenseLayer(boolean firstLayerInBlock, String layerName, String... previousLayers) {
105-
String bnName = "bn1_" + layerName;
106-
String convName = "conv1_" + layerName;
107-
String bnName2 = "bn2_" + layerName;
108-
String convName2 = "conv2_" + layerName;
107+
conf.addLayer(bnLayer.getLayerName(), bnLayer, previousLayers.toArray(String[]::new));
108+
conf.addLayer(layer1x1.getLayerName(), layer1x1, bnLayer.getLayerName());
109+
conf.addLayer(subsamplingLayer.getLayerName(), subsamplingLayer, layer1x1.getLayerName());
109110

110-
if (useBottleNeck) {
111-
conf.addLayer(bnName, new BatchNormalization.Builder()
112-
.build(), previousLayers);
113-
conf.addLayer(convName, new ConvolutionLayer.Builder()
114-
.kernelSize(1, 1)
115-
.stride(1, 1)
116-
.padding(0, 0)
117-
.nOut(growthRate * 2)
118-
.build(), bnName);
119-
}
111+
return subsamplingLayer.getLayerName();
112+
}
120113

121-
conf.addLayer(bnName2, new BatchNormalization.Builder()
122-
.build(), useBottleNeck ? new String[]{convName} : previousLayers);
123-
conf.addLayer(convName2, new ConvolutionLayer.Builder()
114+
private ConvolutionLayer addDenseLayer(String name, String... previousLayers) {
115+
BatchNormalization bnLayer1 = new BatchNormalization.Builder()
116+
.name(String.format("%s_%s", name, "bn1"))
117+
.build();
118+
ConvolutionLayer layer1x1 = new ConvolutionLayer.Builder()
119+
.name(String.format("%s_%s", name, "con1x1"))
120+
.kernelSize(1, 1)
121+
.stride(1, 1)
122+
.padding(0, 0)
123+
.nOut(growthRate * 4)
124+
.build();
125+
BatchNormalization bnLayer2 = new BatchNormalization.Builder()
126+
.name(String.format("%s_%s", name, "bn2"))
127+
.build();
128+
ConvolutionLayer layer3x3 = new ConvolutionLayer.Builder()
129+
.name(String.format("%s_%s", name, "con3x3"))
124130
.kernelSize(3, 3)
125131
.stride(1, 1)
126132
.padding(1, 1)
127133
.nOut(growthRate)
128-
.build(), bnName2);
134+
.build();
135+
136+
if (useBottleNeck) {
137+
conf.addLayer(bnLayer1.getLayerName(), bnLayer1, previousLayers);
138+
conf.addLayer(layer1x1.getLayerName(), layer1x1, bnLayer1.getLayerName());
139+
conf.addLayer(bnLayer2.getLayerName(), bnLayer2, layer1x1.getLayerName());
140+
} else {
141+
conf.addLayer(bnLayer2.getLayerName(), bnLayer2, previousLayers);
142+
}
143+
conf.addLayer(layer3x3.getLayerName(), layer3x3, bnLayer2.getLayerName());
129144

130-
return firstLayerInBlock ? new String[]{convName2} : increaseArray(convName2, previousLayers);
145+
return layer3x3;
131146
}
132147

133-
public String[] addDenseBlock(int numLayers, boolean first, String blockName, String[] previousLayer) {
134-
String layerName = blockName + "_lay" + numLayers;
135-
String[] layersInput = addDenseLayer(first, layerName, previousLayer);
136-
--numLayers;
137-
if (numLayers > 0) {
138-
layersInput = addDenseBlock(numLayers, false, blockName, layersInput);
148+
protected List<String> buildDenseBlock(String blockName, int numLayers, String lastLayerName) {
149+
List<ConvolutionLayer> layers = new ArrayList<>();
150+
for (int i = 0; i < numLayers; ++i) {
151+
layers.add(addDenseLayer(String.format("%s_%s", blockName, i), increaseArray(lastLayerName, getLayerNames(layers))));
139152
}
140-
return first ? increaseArray(previousLayer[0], layersInput) : layersInput;
153+
List<String> names = new ArrayList<>(Arrays.asList(getLayerNames(layers)));
154+
names.add(lastLayerName);
155+
return names;
141156
}
142157

143-
public void addOutputLayer(int height, int width, int numIn, int numLabels, String... previousLayer) {
144-
conf.addLayer("lastBatch", new BatchNormalization.Builder()
145-
.build(), previousLayer);
146-
conf.addLayer("GAP", new GlobalPoolingLayer.Builder()
158+
public void addOutputLayer(int numIn, int numLabels, String... previousLayer) {
159+
GlobalPoolingLayer globalPoolingLayer = new GlobalPoolingLayer.Builder()
160+
.name("outputGPL")
147161
.poolingType(PoolingType.AVG)
148-
.build(), "lastBatch");
149-
conf.addLayer("dense", new DenseLayer.Builder()
150-
.nIn(numIn)
151-
.nOut(1024)
152-
.build(), "GAP");
153-
conf.addLayer("output", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
162+
.collapseDimensions(false)
163+
.build();
164+
BatchNormalization bn2 = new BatchNormalization.Builder()
165+
.name("outputBn")
166+
.build();
167+
ConvolutionLayer convolutionLayer2 = new ConvolutionLayer.Builder()
168+
.name("outputConv")
169+
.kernelSize(1, 1)
170+
.stride(1, 1)
171+
.padding(0, 0)
172+
.nOut(numIn * 2)
173+
.build();
174+
OutputLayer outputLayer = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
175+
.name("output")
154176
.nOut(numLabels)
155177
.activation(Activation.SOFTMAX)
156-
.build(), "dense");
178+
.build();
179+
180+
conf.addLayer(globalPoolingLayer.getLayerName(), globalPoolingLayer, previousLayer);
181+
conf.addLayer(bn2.getLayerName(), bn2, globalPoolingLayer.getLayerName());
182+
conf.addLayer(convolutionLayer2.getLayerName(), convolutionLayer2, bn2.getLayerName());
183+
conf.addLayer(outputLayer.getLayerName(), outputLayer, convolutionLayer2.getLayerName());
157184
}
158185

159-
private String[] increaseArray(String newLayer, String... theArray) {
186+
private String[] increaseArray(String newLayer, String[] theArray) {
160187
String[] newArray = new String[theArray.length + 1];
161188
System.arraycopy(theArray, 0, newArray, 0, theArray.length);
162189
newArray[theArray.length] = newLayer;
163190
return newArray;
164191
}
192+
193+
protected String[] getLayerNames(List<ConvolutionLayer> theArray) {
194+
List<String> names = new ArrayList<>();
195+
if (theArray != null) {
196+
for (ConvolutionLayer convolutionLayer : theArray) {
197+
names.add(convolutionLayer.getLayerName());
198+
}
199+
}
200+
return names.toArray(String[]::new);
201+
}
165202
}

dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/densenet/model/DenseNetModel.java

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,24 @@ public static DenseNetModel getInstance() {
3232

3333
public ComputationGraph buildNetwork(long seed, int channels, int numLabels, int width, int height) {
3434

35-
DenseNetBuilder denseNetModel = new DenseNetBuilder(height, width, channels, seed, 12, false); //227x227x3
36-
37-
String init = denseNetModel.initLayer(7, 2, 1, channels); //56x56x24
38-
String[] block1 = denseNetModel.addDenseBlock(6, true, "db1", new String[]{init});
39-
String trans1 = denseNetModel.addTransitionLayer("tr1", 96, block1); //28x28x48
40-
String[] block2 = denseNetModel.addDenseBlock(12, true, "db2", new String[]{trans1});
41-
String trans2 = denseNetModel.addTransitionLayer("tr2", 192, block2); //14x14x96
42-
String[] block3 = denseNetModel.addDenseBlock(24, true, "db3", new String[]{trans2});
43-
String trans3 = denseNetModel.addTransitionLayer("tr3", 384, block3); //7x7x192
44-
String[] block4 = denseNetModel.addDenseBlock(16, true, "db4", new String[]{trans3});
45-
denseNetModel.addOutputLayer(7, 7, 384, numLabels, block4);
35+
DenseNetBuilder denseNetModel = new DenseNetBuilder(height, width, channels, seed, 12, true); //227x227x3
36+
37+
int l1 = 6, l2 = 12, l3 = 24, l4 = 16;
38+
39+
int nIn1 = l1 * denseNetModel.getGrowthRate() + 2 * denseNetModel.getGrowthRate();
40+
int nIn2 = l2 * denseNetModel.getGrowthRate() + nIn1 / 2;
41+
int nIn3 = l3 * denseNetModel.getGrowthRate() + nIn2 / 2;
42+
int nIn4 = l4 * denseNetModel.getGrowthRate() + nIn3 / 2;
43+
44+
String init = denseNetModel.initLayer(5, 2, 1, channels);
45+
List<String> block1 = denseNetModel.buildDenseBlock("b1", l1, init);
46+
String trans1 = denseNetModel.addTransitionLayer("t1", nIn1, block1);
47+
List<String> block2 = denseNetModel.buildDenseBlock("b2", l2, trans1);
48+
String trans2 = denseNetModel.addTransitionLayer("t2", nIn2, block2);
49+
List<String> block3 = denseNetModel.buildDenseBlock("b3", l3, trans2);
50+
String trans3 = denseNetModel.addTransitionLayer("t3", nIn3, block3);
51+
List<String> block4 = denseNetModel.buildDenseBlock("b4", l4, trans3);
52+
denseNetModel.addOutputLayer(nIn4, numLabels, block4.toArray(String[]::new));
4653

4754
return denseNetModel.getModel();
4855
}

0 commit comments

Comments (0)