Skip to content

Commit 4e8337b

Browse files
committed
Increase CV validation F2 to 0.79.
- Rearranged error-margin max-pooling in the network.
- Decreased dropout.
- Removed learning-rate reduction.
1 parent 45c0384 commit 4e8337b

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

deepcalcium/models/spikes/unet_1d_segmentation.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def on_epoch_end(self, epoch, logs):
4545
dpi=120)
4646

4747

48-
def unet1d(window_shape=(128,), nb_filters_base=32, conv_kernel_init='he_normal', prop_dropout_base=0.1, margin=4):
48+
def unet1d(window_shape=(128,), nb_filters_base=32, conv_kernel_init='he_normal', prop_dropout_base=0.05, margin=4):
4949
"""Builds and returns the UNet architecture using Keras.
5050
# Arguments
5151
window_shape: tuple of one integer defining the input/output window shape.
@@ -134,10 +134,12 @@ def conv_layer(nbf, x):
134134
x = conv_layer(nfb, x)
135135
x = conv_layer(nfb, x)
136136

137-
# Apply the error margin before softmax activation.
138-
x = MaxPooling1D(margin + 1, strides=1, padding='same')(x)
139-
x = Conv1D(2, 1, activation='softmax')(x)
137+
x = Conv1D(2, 1)(x)
138+
x = MaxPooling1D(margin+1, strides=1, padding='same')(x)
139+
x = Activation('softmax')(x)
140140

141+
#x = Lambda(lambda x: x[:, :, 1:])(x)
142+
#x = MaxPooling1D(margin + 1, strides=1, padding='same')(x)
141143
x = Lambda(lambda x: x[:, :, -1])(x)
142144
model = Model(inputs=inputs, outputs=x)
143145

@@ -205,7 +207,7 @@ def __init__(self, cpdir='%s/spikes_unet1d' % CHECKPOINTS_DIR,
205207

206208
def fit(self, dataset_paths, shape=(4096,), error_margin=4.,
207209
batch=20, nb_epochs=20, val_type='random_split', prop_trn=0.8,
208-
prop_val=0.2, nb_folds=5, keras_callbacks=[], optimizer=Adam(0.001)):
210+
prop_val=0.2, nb_folds=5, keras_callbacks=[], optimizer=Adam(0.002)):
209211
"""Constructs model based on parameters and trains with the given data.
210212
Internally, the function uses a local function to abstract the training
211213
for both validation types.
@@ -282,10 +284,7 @@ def loss(yt, yp):
282284
ModelCheckpoint('%s/%d_model_val_F2_{val_F2:3f}_{epoch:03d}.hdf5' % cpt,
283285
monitor='val_F2', mode='max', verbose=1, save_best_only=True),
284286
CSVLogger('%s/%d_metrics.csv' % cpt),
285-
MetricsPlotCallback('%s/%d_metrics.png' % cpt),
286-
ReduceLROnPlateau(monitor='val_F2', factor=0.5, min_lr=0.0001,
287-
patience=max(10, int(nb_epochs * 0.2)),
288-
mode='max', epsilon=1e-2, verbose=1)
287+
MetricsPlotCallback('%s/%d_metrics.png' % cpt)
289288
]
290289

291290
# Train.

0 commit comments

Comments (0)