@@ -45,7 +45,7 @@ def on_epoch_end(self, epoch, logs):
4545 dpi = 120 )
4646
4747
48- def unet1d (window_shape = (128 ,), nb_filters_base = 32 , conv_kernel_init = 'he_normal' , prop_dropout_base = 0.1 , margin = 4 ):
48+ def unet1d (window_shape = (128 ,), nb_filters_base = 32 , conv_kernel_init = 'he_normal' , prop_dropout_base = 0.05 , margin = 4 ):
4949 """Builds and returns the UNet architecture using Keras.
5050 # Arguments
5151 window_shape: tuple of one integer defining the input/output window shape.
@@ -134,10 +134,12 @@ def conv_layer(nbf, x):
134134 x = conv_layer (nfb , x )
135135 x = conv_layer (nfb , x )
136136
137- # Apply the error margin before softmax activation.
138- x = MaxPooling1D (margin + 1 , strides = 1 , padding = 'same' )(x )
139- x = Conv1D ( 2 , 1 , activation = 'softmax' )(x )
137+ x = Conv1D ( 2 , 1 )( x )
138+ x = MaxPooling1D (margin + 1 , strides = 1 , padding = 'same' )(x )
139+ x = Activation ( 'softmax' )(x )
140140
141+ #x = Lambda(lambda x: x[:, :, 1:])(x)
142+ #x = MaxPooling1D(margin + 1, strides=1, padding='same')(x)
141143 x = Lambda (lambda x : x [:, :, - 1 ])(x )
142144 model = Model (inputs = inputs , outputs = x )
143145
@@ -205,7 +207,7 @@ def __init__(self, cpdir='%s/spikes_unet1d' % CHECKPOINTS_DIR,
205207
206208 def fit (self , dataset_paths , shape = (4096 ,), error_margin = 4. ,
207209 batch = 20 , nb_epochs = 20 , val_type = 'random_split' , prop_trn = 0.8 ,
208- prop_val = 0.2 , nb_folds = 5 , keras_callbacks = [], optimizer = Adam (0.001 )):
210+ prop_val = 0.2 , nb_folds = 5 , keras_callbacks = [], optimizer = Adam (0.002 )):
209211 """Constructs model based on parameters and trains with the given data.
210212 Internally, the function uses a local function to abstract the training
211213 for both validation types.
@@ -282,10 +284,7 @@ def loss(yt, yp):
282284 ModelCheckpoint ('%s/%d_model_val_F2_{val_F2:3f}_{epoch:03d}.hdf5' % cpt ,
283285 monitor = 'val_F2' , mode = 'max' , verbose = 1 , save_best_only = True ),
284286 CSVLogger ('%s/%d_metrics.csv' % cpt ),
285- MetricsPlotCallback ('%s/%d_metrics.png' % cpt ),
286- ReduceLROnPlateau (monitor = 'val_F2' , factor = 0.5 , min_lr = 0.0001 ,
287- patience = max (10 , int (nb_epochs * 0.2 )),
288- mode = 'max' , epsilon = 1e-2 , verbose = 1 )
287+ MetricsPlotCallback ('%s/%d_metrics.png' % cpt )
289288 ]
290289
291290 # Train.
0 commit comments