
Commit c486dd8: add convolution
1 parent 2a9e4d0

1 file changed: +12 -5 lines

P1B3/p1b3_baseline.py (+12 -5)
@@ -17,7 +17,7 @@
 from keras import backend as K
 from keras import metrics
 from keras.models import Sequential
-from keras.layers import Dense, Dropout, LocallyConnected1D, MaxPooling1D, Flatten
+from keras.layers import Dense, Dropout, LocallyConnected1D, Convolution1D, MaxPooling1D, Flatten
 from keras.callbacks import Callback, ModelCheckpoint, ProgbarLogger
 
 from sklearn.preprocessing import Imputer
@@ -65,7 +65,7 @@
 DENSE_LAYERS = [D1, D2, D3, D4]
 
 # Number of units per locally connected layer
-LC1 = 10, 1 # nb_filter, filter_length
+LC1 = 10, 10 # nb_filter, filter_length
 LC2 = 0, 0 # disabled layer
 # LOCALLY_CONNECTED_LAYERS = list(LC1 + LC2)
 LOCALLY_CONNECTED_LAYERS = [0, 0]
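
Note on the pair encoding (a minimal sketch, not part of the commit): each LCn constant is a (nb_filter, filter_length) tuple, and concatenating the tuples yields the flat list format the commented-out line suggests.

# Sketch: the (nb_filter, filter_length) pairs flatten into the list format
# consumed by --locally_connected; a (0, 0) pair disables that layer.
LC1 = 10, 10
LC2 = 0, 0
print(list(LC1 + LC2))  # [10, 10, 0, 0]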
@@ -91,6 +91,9 @@ def get_parser():
     parser.add_argument("-b", "--batch_size", action="store",
                         default=BATCH_SIZE, type=int,
                         help="batch size")
+    parser.add_argument("-c", "--convolution", action="store_true",
+                        default=False,
+                        help="use convolution layers instead of locally connected layers")
     parser.add_argument("-d", "--dense", action="store", nargs='+', type=int,
                         default=DENSE_LAYERS,
                         help="number of units in fully connected layers in an integer array")
@@ -161,13 +164,14 @@ def extension_from_parameters(args):
     if args.feature_subsample:
         ext += '.F={}'.format(args.feature_subsample)
     if args.locally_connected:
+        name = 'C' if args.convolution else 'LC'
         layer_list = list(range(0, len(args.locally_connected), 2))
         for l, i in enumerate(layer_list):
             nb_filter = args.locally_connected[i]
             filter_len = args.locally_connected[i+1]
             if nb_filter <= 0 or filter_len <= 0:
                 break
-            ext += '.LC{}={},{}'.format(l+1, nb_filter, filter_len)
+            ext += '.{}{}={},{}'.format(name, l+1, nb_filter, filter_len)
     if args.pool and layer_list[0] and layer_list[1]:
         ext += '.P={}'.format(args.pool)
     for i, n in enumerate(args.dense):
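
The effect on the run-name suffix, shown as a minimal self-contained sketch of the loop above (layer_suffix is a hypothetical helper, not in the script): with --convolution the layer tags switch from LC to C.

# Hypothetical helper distilling the suffix loop above.
def layer_suffix(locally_connected, convolution):
    name = 'C' if convolution else 'LC'
    ext = ''
    for l, i in enumerate(range(0, len(locally_connected), 2)):
        nb_filter, filter_len = locally_connected[i], locally_connected[i+1]
        if nb_filter <= 0 or filter_len <= 0:
            break
        ext += '.{}{}={},{}'.format(name, l+1, nb_filter, filter_len)
    return ext

print(layer_suffix([10, 10], convolution=False))  # .LC1=10,10
print(layer_suffix([10, 10], convolution=True))   # .C1=10,10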
@@ -296,7 +300,7 @@ def on_epoch_end(self, epoch, logs=None):
 def main():
     parser = get_parser()
     args = parser.parse_args()
-    print('Command line args =', args)
+    print('Args:', args)
 
     loggingLevel = logging.DEBUG if args.verbose else logging.INFO
     logging.basicConfig(level=loggingLevel, format='')
@@ -324,7 +328,10 @@ def main():
             filter_len = args.locally_connected[i+1]
             if nb_filter <= 0 or filter_len <= 0:
                 break
-            model.add(LocallyConnected1D(nb_filter, filter_len, input_shape=(datagen.input_dim, 1), activation=args.activation))
+            if args.convolution:
+                model.add(Convolution1D(nb_filter, filter_len, input_shape=(datagen.input_dim, 1), activation=args.activation))
+            else:
+                model.add(LocallyConnected1D(nb_filter, filter_len, input_shape=(datagen.input_dim, 1), activation=args.activation))
             if args.pool:
                 model.add(MaxPooling1D(pool_length=args.pool))
         model.add(Flatten())
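
Why the switch matters, as a rough sketch (assumes the Keras 1.x API this script uses; the input length 100 is an arbitrary stand-in for datagen.input_dim): Convolution1D shares one set of filter weights across all positions, while LocallyConnected1D learns separate weights at every position, so the convolutional variant has far fewer parameters for the same (nb_filter, filter_len) pair.

from keras.models import Sequential
from keras.layers import Convolution1D, LocallyConnected1D

# Same (nb_filter, filter_len) = (10, 10) config for both layer types;
# only the weight sharing differs, which shows up in the parameter count.
for layer in (Convolution1D, LocallyConnected1D):
    model = Sequential()
    model.add(layer(10, 10, input_shape=(100, 1), activation='relu'))
    print(layer.__name__, model.count_params())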
