@@ -62,104 +62,141 @@ public ComputationGraph getModel() {
62
62
return model ;
63
63
}
64
64
65
+ public int getGrowthRate () {
66
+ return growthRate ;
67
+ }
68
+
65
69
public String initLayer (int kernel , int stride , int padding , int channels ) {
66
- String init = "initConv" ;
67
- String initPool = "initPool" ;
68
- conf .addLayer (init , new ConvolutionLayer .Builder ()
70
+ ConvolutionLayer convolutionLayer = new ConvolutionLayer .Builder ()
71
+ .name ("initConv" )
69
72
.kernelSize (kernel , kernel )
70
73
.stride (stride , stride )
71
74
.padding (padding , padding )
72
75
.nIn (channels )
73
76
.nOut (growthRate * 2 )
74
- .build (), "input" );
75
- conf .addLayer (initPool , new Pooling2D .Builder (SubsamplingLayer .PoolingType .MAX )
76
- .kernelSize (2 , 2 )
77
+ .build ();
78
+ SubsamplingLayer subsamplingLayer = new Pooling2D .Builder (SubsamplingLayer .PoolingType .MAX )
79
+ .name ("initPool" )
80
+ .kernelSize (3 , 3 )
77
81
.padding (0 , 0 )
78
- .build (), init );
79
- return initPool ;
80
- }
82
+ .build ();
83
+
84
+ conf .addLayer (convolutionLayer .getLayerName (), convolutionLayer , "input" );
85
+ conf .addLayer (subsamplingLayer .getLayerName (), subsamplingLayer , convolutionLayer .getLayerName ());
81
86
82
- public String addTransitionLayer (String transitionName , int numIn , String ... previousBlock ) {
83
- String bnName = "bn_" + transitionName ;
84
- String convName = "conv_" + transitionName ;
85
- String poolName = "pool_" + transitionName ;
87
+ return subsamplingLayer .getLayerName ();
88
+ }
86
89
87
- conf .addLayer (bnName , new BatchNormalization .Builder ()
88
- .build (), previousBlock );
89
- conf .addLayer (convName , new ConvolutionLayer .Builder ()
90
+ public String addTransitionLayer (String name , long numIn , List <String > previousLayers ) {
91
+ BatchNormalization bnLayer = new BatchNormalization .Builder ()
92
+ .name (String .format ("%s_%s" , name , "bn" ))
93
+ .build ();
94
+ ConvolutionLayer layer1x1 = new ConvolutionLayer .Builder ()
95
+ .name (String .format ("%s_%s" , name , "conv" ))
90
96
.kernelSize (1 , 1 )
91
97
.stride (1 , 1 )
92
98
.padding (0 , 0 )
93
99
.nOut (numIn / 2 )
94
- .build (), bnName );
95
- conf .addLayer (poolName , new Pooling2D .Builder (SubsamplingLayer .PoolingType .AVG )
100
+ .build ();
101
+ SubsamplingLayer subsamplingLayer = new Pooling2D .Builder (SubsamplingLayer .PoolingType .AVG )
102
+ .name (String .format ("%s_%s" , name , "pool" ))
96
103
.kernelSize (2 , 2 )
97
104
.padding (0 , 0 )
98
- .build (), convName );
99
- ;
105
+ .build ();
100
106
101
- return poolName ;
102
- }
103
-
104
- private String [] addDenseLayer (boolean firstLayerInBlock , String layerName , String ... previousLayers ) {
105
- String bnName = "bn1_" + layerName ;
106
- String convName = "conv1_" + layerName ;
107
- String bnName2 = "bn2_" + layerName ;
108
- String convName2 = "conv2_" + layerName ;
107
+ conf .addLayer (bnLayer .getLayerName (), bnLayer , previousLayers .toArray (String []::new ));
108
+ conf .addLayer (layer1x1 .getLayerName (), layer1x1 , bnLayer .getLayerName ());
109
+ conf .addLayer (subsamplingLayer .getLayerName (), subsamplingLayer , layer1x1 .getLayerName ());
109
110
110
- if (useBottleNeck ) {
111
- conf .addLayer (bnName , new BatchNormalization .Builder ()
112
- .build (), previousLayers );
113
- conf .addLayer (convName , new ConvolutionLayer .Builder ()
114
- .kernelSize (1 , 1 )
115
- .stride (1 , 1 )
116
- .padding (0 , 0 )
117
- .nOut (growthRate * 2 )
118
- .build (), bnName );
119
- }
111
+ return subsamplingLayer .getLayerName ();
112
+ }
120
113
121
- conf .addLayer (bnName2 , new BatchNormalization .Builder ()
122
- .build (), useBottleNeck ? new String []{convName } : previousLayers );
123
- conf .addLayer (convName2 , new ConvolutionLayer .Builder ()
114
+ private ConvolutionLayer addDenseLayer (String name , String ... previousLayers ) {
115
+ BatchNormalization bnLayer1 = new BatchNormalization .Builder ()
116
+ .name (String .format ("%s_%s" , name , "bn1" ))
117
+ .build ();
118
+ ConvolutionLayer layer1x1 = new ConvolutionLayer .Builder ()
119
+ .name (String .format ("%s_%s" , name , "con1x1" ))
120
+ .kernelSize (1 , 1 )
121
+ .stride (1 , 1 )
122
+ .padding (0 , 0 )
123
+ .nOut (growthRate * 4 )
124
+ .build ();
125
+ BatchNormalization bnLayer2 = new BatchNormalization .Builder ()
126
+ .name (String .format ("%s_%s" , name , "bn2" ))
127
+ .build ();
128
+ ConvolutionLayer layer3x3 = new ConvolutionLayer .Builder ()
129
+ .name (String .format ("%s_%s" , name , "con3x3" ))
124
130
.kernelSize (3 , 3 )
125
131
.stride (1 , 1 )
126
132
.padding (1 , 1 )
127
133
.nOut (growthRate )
128
- .build (), bnName2 );
134
+ .build ();
135
+
136
+ if (useBottleNeck ) {
137
+ conf .addLayer (bnLayer1 .getLayerName (), bnLayer1 , previousLayers );
138
+ conf .addLayer (layer1x1 .getLayerName (), layer1x1 , bnLayer1 .getLayerName ());
139
+ conf .addLayer (bnLayer2 .getLayerName (), bnLayer2 , layer1x1 .getLayerName ());
140
+ } else {
141
+ conf .addLayer (bnLayer2 .getLayerName (), bnLayer2 , previousLayers );
142
+ }
143
+ conf .addLayer (layer3x3 .getLayerName (), layer3x3 , bnLayer2 .getLayerName ());
129
144
130
- return firstLayerInBlock ? new String []{ convName2 } : increaseArray ( convName2 , previousLayers ) ;
145
+ return layer3x3 ;
131
146
}
132
147
133
- public String [] addDenseBlock (int numLayers , boolean first , String blockName , String [] previousLayer ) {
134
- String layerName = blockName + "_lay" + numLayers ;
135
- String [] layersInput = addDenseLayer (first , layerName , previousLayer );
136
- --numLayers ;
137
- if (numLayers > 0 ) {
138
- layersInput = addDenseBlock (numLayers , false , blockName , layersInput );
148
+ protected List <String > buildDenseBlock (String blockName , int numLayers , String lastLayerName ) {
149
+ List <ConvolutionLayer > layers = new ArrayList <>();
150
+ for (int i = 0 ; i < numLayers ; ++i ) {
151
+ layers .add (addDenseLayer (String .format ("%s_%s" , blockName , i ), increaseArray (lastLayerName , getLayerNames (layers ))));
139
152
}
140
- return first ? increaseArray (previousLayer [0 ], layersInput ) : layersInput ;
153
+ List <String > names = new ArrayList <>(Arrays .asList (getLayerNames (layers )));
154
+ names .add (lastLayerName );
155
+ return names ;
141
156
}
142
157
143
- public void addOutputLayer (int height , int width , int numIn , int numLabels , String ... previousLayer ) {
144
- conf .addLayer ("lastBatch" , new BatchNormalization .Builder ()
145
- .build (), previousLayer );
146
- conf .addLayer ("GAP" , new GlobalPoolingLayer .Builder ()
158
+ public void addOutputLayer (int numIn , int numLabels , String ... previousLayer ) {
159
+ GlobalPoolingLayer globalPoolingLayer = new GlobalPoolingLayer .Builder ()
160
+ .name ("outputGPL" )
147
161
.poolingType (PoolingType .AVG )
148
- .build (), "lastBatch" );
149
- conf .addLayer ("dense" , new DenseLayer .Builder ()
150
- .nIn (numIn )
151
- .nOut (1024 )
152
- .build (), "GAP" );
153
- conf .addLayer ("output" , new OutputLayer .Builder (LossFunctions .LossFunction .NEGATIVELOGLIKELIHOOD )
162
+ .collapseDimensions (false )
163
+ .build ();
164
+ BatchNormalization bn2 = new BatchNormalization .Builder ()
165
+ .name ("outputBn" )
166
+ .build ();
167
+ ConvolutionLayer convolutionLayer2 = new ConvolutionLayer .Builder ()
168
+ .name ("outputConv" )
169
+ .kernelSize (1 , 1 )
170
+ .stride (1 , 1 )
171
+ .padding (0 , 0 )
172
+ .nOut (numIn * 2 )
173
+ .build ();
174
+ OutputLayer outputLayer = new OutputLayer .Builder (LossFunctions .LossFunction .NEGATIVELOGLIKELIHOOD )
175
+ .name ("output" )
154
176
.nOut (numLabels )
155
177
.activation (Activation .SOFTMAX )
156
- .build (), "dense" );
178
+ .build ();
179
+
180
+ conf .addLayer (globalPoolingLayer .getLayerName (), globalPoolingLayer , previousLayer );
181
+ conf .addLayer (bn2 .getLayerName (), bn2 , globalPoolingLayer .getLayerName ());
182
+ conf .addLayer (convolutionLayer2 .getLayerName (), convolutionLayer2 , bn2 .getLayerName ());
183
+ conf .addLayer (outputLayer .getLayerName (), outputLayer , convolutionLayer2 .getLayerName ());
157
184
}
158
185
159
- private String [] increaseArray (String newLayer , String ... theArray ) {
186
+ private String [] increaseArray (String newLayer , String [] theArray ) {
160
187
String [] newArray = new String [theArray .length + 1 ];
161
188
System .arraycopy (theArray , 0 , newArray , 0 , theArray .length );
162
189
newArray [theArray .length ] = newLayer ;
163
190
return newArray ;
164
191
}
192
+
193
+ protected String [] getLayerNames (List <ConvolutionLayer > theArray ) {
194
+ List <String > names = new ArrayList <>();
195
+ if (theArray != null ) {
196
+ for (ConvolutionLayer convolutionLayer : theArray ) {
197
+ names .add (convolutionLayer .getLayerName ());
198
+ }
199
+ }
200
+ return names .toArray (String []::new );
201
+ }
165
202
}
0 commit comments