@@ -62,104 +62,141 @@ public ComputationGraph getModel() {
6262 return model ;
6363 }
6464
65+ public int getGrowthRate () {
66+ return growthRate ;
67+ }
68+
6569 public String initLayer (int kernel , int stride , int padding , int channels ) {
66- String init = "initConv" ;
67- String initPool = "initPool" ;
68- conf .addLayer (init , new ConvolutionLayer .Builder ()
70+ ConvolutionLayer convolutionLayer = new ConvolutionLayer .Builder ()
71+ .name ("initConv" )
6972 .kernelSize (kernel , kernel )
7073 .stride (stride , stride )
7174 .padding (padding , padding )
7275 .nIn (channels )
7376 .nOut (growthRate * 2 )
74- .build (), "input" );
75- conf .addLayer (initPool , new Pooling2D .Builder (SubsamplingLayer .PoolingType .MAX )
76- .kernelSize (2 , 2 )
77+ .build ();
78+ SubsamplingLayer subsamplingLayer = new Pooling2D .Builder (SubsamplingLayer .PoolingType .MAX )
79+ .name ("initPool" )
80+ .kernelSize (3 , 3 )
7781 .padding (0 , 0 )
78- .build (), init );
79- return initPool ;
80- }
82+ .build ();
83+
84+ conf .addLayer (convolutionLayer .getLayerName (), convolutionLayer , "input" );
85+ conf .addLayer (subsamplingLayer .getLayerName (), subsamplingLayer , convolutionLayer .getLayerName ());
8186
82- public String addTransitionLayer (String transitionName , int numIn , String ... previousBlock ) {
83- String bnName = "bn_" + transitionName ;
84- String convName = "conv_" + transitionName ;
85- String poolName = "pool_" + transitionName ;
87+ return subsamplingLayer .getLayerName ();
88+ }
8689
87- conf .addLayer (bnName , new BatchNormalization .Builder ()
88- .build (), previousBlock );
89- conf .addLayer (convName , new ConvolutionLayer .Builder ()
90+ public String addTransitionLayer (String name , long numIn , List <String > previousLayers ) {
91+ BatchNormalization bnLayer = new BatchNormalization .Builder ()
92+ .name (String .format ("%s_%s" , name , "bn" ))
93+ .build ();
94+ ConvolutionLayer layer1x1 = new ConvolutionLayer .Builder ()
95+ .name (String .format ("%s_%s" , name , "conv" ))
9096 .kernelSize (1 , 1 )
9197 .stride (1 , 1 )
9298 .padding (0 , 0 )
9399 .nOut (numIn / 2 )
94- .build (), bnName );
95- conf .addLayer (poolName , new Pooling2D .Builder (SubsamplingLayer .PoolingType .AVG )
100+ .build ();
101+ SubsamplingLayer subsamplingLayer = new Pooling2D .Builder (SubsamplingLayer .PoolingType .AVG )
102+ .name (String .format ("%s_%s" , name , "pool" ))
96103 .kernelSize (2 , 2 )
97104 .padding (0 , 0 )
98- .build (), convName );
99- ;
105+ .build ();
100106
101- return poolName ;
102- }
103-
104- private String [] addDenseLayer (boolean firstLayerInBlock , String layerName , String ... previousLayers ) {
105- String bnName = "bn1_" + layerName ;
106- String convName = "conv1_" + layerName ;
107- String bnName2 = "bn2_" + layerName ;
108- String convName2 = "conv2_" + layerName ;
107+ conf .addLayer (bnLayer .getLayerName (), bnLayer , previousLayers .toArray (String []::new ));
108+ conf .addLayer (layer1x1 .getLayerName (), layer1x1 , bnLayer .getLayerName ());
109+ conf .addLayer (subsamplingLayer .getLayerName (), subsamplingLayer , layer1x1 .getLayerName ());
109110
110- if (useBottleNeck ) {
111- conf .addLayer (bnName , new BatchNormalization .Builder ()
112- .build (), previousLayers );
113- conf .addLayer (convName , new ConvolutionLayer .Builder ()
114- .kernelSize (1 , 1 )
115- .stride (1 , 1 )
116- .padding (0 , 0 )
117- .nOut (growthRate * 2 )
118- .build (), bnName );
119- }
111+ return subsamplingLayer .getLayerName ();
112+ }
120113
121- conf .addLayer (bnName2 , new BatchNormalization .Builder ()
122- .build (), useBottleNeck ? new String []{convName } : previousLayers );
123- conf .addLayer (convName2 , new ConvolutionLayer .Builder ()
114+ private ConvolutionLayer addDenseLayer (String name , String ... previousLayers ) {
115+ BatchNormalization bnLayer1 = new BatchNormalization .Builder ()
116+ .name (String .format ("%s_%s" , name , "bn1" ))
117+ .build ();
118+ ConvolutionLayer layer1x1 = new ConvolutionLayer .Builder ()
119+ .name (String .format ("%s_%s" , name , "con1x1" ))
120+ .kernelSize (1 , 1 )
121+ .stride (1 , 1 )
122+ .padding (0 , 0 )
123+ .nOut (growthRate * 4 )
124+ .build ();
125+ BatchNormalization bnLayer2 = new BatchNormalization .Builder ()
126+ .name (String .format ("%s_%s" , name , "bn2" ))
127+ .build ();
128+ ConvolutionLayer layer3x3 = new ConvolutionLayer .Builder ()
129+ .name (String .format ("%s_%s" , name , "con3x3" ))
124130 .kernelSize (3 , 3 )
125131 .stride (1 , 1 )
126132 .padding (1 , 1 )
127133 .nOut (growthRate )
128- .build (), bnName2 );
134+ .build ();
135+
136+ if (useBottleNeck ) {
137+ conf .addLayer (bnLayer1 .getLayerName (), bnLayer1 , previousLayers );
138+ conf .addLayer (layer1x1 .getLayerName (), layer1x1 , bnLayer1 .getLayerName ());
139+ conf .addLayer (bnLayer2 .getLayerName (), bnLayer2 , layer1x1 .getLayerName ());
140+ } else {
141+ conf .addLayer (bnLayer2 .getLayerName (), bnLayer2 , previousLayers );
142+ }
143+ conf .addLayer (layer3x3 .getLayerName (), layer3x3 , bnLayer2 .getLayerName ());
129144
130- return firstLayerInBlock ? new String []{ convName2 } : increaseArray ( convName2 , previousLayers ) ;
145+ return layer3x3 ;
131146 }
132147
133- public String [] addDenseBlock (int numLayers , boolean first , String blockName , String [] previousLayer ) {
134- String layerName = blockName + "_lay" + numLayers ;
135- String [] layersInput = addDenseLayer (first , layerName , previousLayer );
136- --numLayers ;
137- if (numLayers > 0 ) {
138- layersInput = addDenseBlock (numLayers , false , blockName , layersInput );
148+ protected List <String > buildDenseBlock (String blockName , int numLayers , String lastLayerName ) {
149+ List <ConvolutionLayer > layers = new ArrayList <>();
150+ for (int i = 0 ; i < numLayers ; ++i ) {
151+ layers .add (addDenseLayer (String .format ("%s_%s" , blockName , i ), increaseArray (lastLayerName , getLayerNames (layers ))));
139152 }
140- return first ? increaseArray (previousLayer [0 ], layersInput ) : layersInput ;
153+ List <String > names = new ArrayList <>(Arrays .asList (getLayerNames (layers )));
154+ names .add (lastLayerName );
155+ return names ;
141156 }
142157
143- public void addOutputLayer (int height , int width , int numIn , int numLabels , String ... previousLayer ) {
144- conf .addLayer ("lastBatch" , new BatchNormalization .Builder ()
145- .build (), previousLayer );
146- conf .addLayer ("GAP" , new GlobalPoolingLayer .Builder ()
158+ public void addOutputLayer (int numIn , int numLabels , String ... previousLayer ) {
159+ GlobalPoolingLayer globalPoolingLayer = new GlobalPoolingLayer .Builder ()
160+ .name ("outputGPL" )
147161 .poolingType (PoolingType .AVG )
148- .build (), "lastBatch" );
149- conf .addLayer ("dense" , new DenseLayer .Builder ()
150- .nIn (numIn )
151- .nOut (1024 )
152- .build (), "GAP" );
153- conf .addLayer ("output" , new OutputLayer .Builder (LossFunctions .LossFunction .NEGATIVELOGLIKELIHOOD )
162+ .collapseDimensions (false )
163+ .build ();
164+ BatchNormalization bn2 = new BatchNormalization .Builder ()
165+ .name ("outputBn" )
166+ .build ();
167+ ConvolutionLayer convolutionLayer2 = new ConvolutionLayer .Builder ()
168+ .name ("outputConv" )
169+ .kernelSize (1 , 1 )
170+ .stride (1 , 1 )
171+ .padding (0 , 0 )
172+ .nOut (numIn * 2 )
173+ .build ();
174+ OutputLayer outputLayer = new OutputLayer .Builder (LossFunctions .LossFunction .NEGATIVELOGLIKELIHOOD )
175+ .name ("output" )
154176 .nOut (numLabels )
155177 .activation (Activation .SOFTMAX )
156- .build (), "dense" );
178+ .build ();
179+
180+ conf .addLayer (globalPoolingLayer .getLayerName (), globalPoolingLayer , previousLayer );
181+ conf .addLayer (bn2 .getLayerName (), bn2 , globalPoolingLayer .getLayerName ());
182+ conf .addLayer (convolutionLayer2 .getLayerName (), convolutionLayer2 , bn2 .getLayerName ());
183+ conf .addLayer (outputLayer .getLayerName (), outputLayer , convolutionLayer2 .getLayerName ());
157184 }
158185
159- private String [] increaseArray (String newLayer , String ... theArray ) {
186+ private String [] increaseArray (String newLayer , String [] theArray ) {
160187 String [] newArray = new String [theArray .length + 1 ];
161188 System .arraycopy (theArray , 0 , newArray , 0 , theArray .length );
162189 newArray [theArray .length ] = newLayer ;
163190 return newArray ;
164191 }
192+
193+ protected String [] getLayerNames (List <ConvolutionLayer > theArray ) {
194+ List <String > names = new ArrayList <>();
195+ if (theArray != null ) {
196+ for (ConvolutionLayer convolutionLayer : theArray ) {
197+ names .add (convolutionLayer .getLayerName ());
198+ }
199+ }
200+ return names .toArray (String []::new );
201+ }
165202}
0 commit comments