17
17
from keras import backend as K
18
18
from keras import metrics
19
19
from keras .models import Sequential
20
- from keras .layers import Dense , Dropout , LocallyConnected1D , MaxPooling1D , Flatten
20
+ from keras .layers import Dense , Dropout , LocallyConnected1D , Convolution1D , MaxPooling1D , Flatten
21
21
from keras .callbacks import Callback , ModelCheckpoint , ProgbarLogger
22
22
23
23
from sklearn .preprocessing import Imputer
65
65
DENSE_LAYERS = [D1 , D2 , D3 , D4 ]
66
66
67
67
# Number of units per locally connected layer
68
- LC1 = 10 , 1 # nb_filter, filter_length
68
+ LC1 = 10 , 10 # nb_filter, filter_length
69
69
LC2 = 0 , 0 # disabled layer
70
70
# LOCALLY_CONNECTED_LAYERS = list(LC1 + LC2)
71
71
LOCALLY_CONNECTED_LAYERS = [0 , 0 ]
@@ -91,6 +91,9 @@ def get_parser():
91
91
parser .add_argument ("-b" , "--batch_size" , action = "store" ,
92
92
default = BATCH_SIZE , type = int ,
93
93
help = "batch size" )
94
+ parser .add_argument ("-c" , "--convolution" , action = "store_true" ,
95
+ default = False ,
96
+ help = "use convolution layers instead of locally connected layers" )
94
97
parser .add_argument ("-d" , "--dense" , action = "store" , nargs = '+' , type = int ,
95
98
default = DENSE_LAYERS ,
96
99
help = "number of units in fully connected layers in an integer array" )
@@ -161,13 +164,14 @@ def extension_from_parameters(args):
161
164
if args .feature_subsample :
162
165
ext += '.F={}' .format (args .feature_subsample )
163
166
if args .locally_connected :
167
+ name = 'C' if args .convolution else 'LC'
164
168
layer_list = list (range (0 , len (args .locally_connected ), 2 ))
165
169
for l , i in enumerate (layer_list ):
166
170
nb_filter = args .locally_connected [i ]
167
171
filter_len = args .locally_connected [i + 1 ]
168
172
if nb_filter <= 0 or filter_len <= 0 :
169
173
break
170
- ext += '.LC{} ={},{}' .format (l + 1 , nb_filter , filter_len )
174
+ ext += '.{}{} ={},{}' .format (name , l + 1 , nb_filter , filter_len )
171
175
if args .pool and layer_list [0 ] and layer_list [1 ]:
172
176
ext += '.P={}' .format (args .pool )
173
177
for i , n in enumerate (args .dense ):
@@ -296,7 +300,7 @@ def on_epoch_end(self, epoch, logs=None):
296
300
def main ():
297
301
parser = get_parser ()
298
302
args = parser .parse_args ()
299
- print ('Command line args = ' , args )
303
+ print ('Args: ' , args )
300
304
301
305
loggingLevel = logging .DEBUG if args .verbose else logging .INFO
302
306
logging .basicConfig (level = loggingLevel , format = '' )
@@ -324,7 +328,10 @@ def main():
324
328
filter_len = args .locally_connected [i + 1 ]
325
329
if nb_filter <= 0 or filter_len <= 0 :
326
330
break
327
- model .add (LocallyConnected1D (nb_filter , filter_len , input_shape = (datagen .input_dim , 1 ), activation = args .activation ))
331
+ if args .convolution :
332
+ model .add (Convolution1D (nb_filter , filter_len , input_shape = (datagen .input_dim , 1 ), activation = args .activation ))
333
+ else :
334
+ model .add (LocallyConnected1D (nb_filter , filter_len , input_shape = (datagen .input_dim , 1 ), activation = args .activation ))
328
335
if args .pool :
329
336
model .add (MaxPooling1D (pool_length = args .pool ))
330
337
model .add (Flatten ())
0 commit comments