Skip to content

Commit 6c29468

Browse files
authored
Merge pull request #27 from levinas/release_01
add PermanentDropout
2 parents 8346298 + f7d095a commit 6c29468

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

common/keras_utils.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from keras import optimizers
66
from keras import initializers
77

8+
from keras.layers import Dropout
9+
from keras.utils import get_custom_objects
810
from keras.metrics import binary_crossentropy, mean_squared_error
911

1012
from scipy.stats.stats import pearsonr
@@ -38,11 +40,11 @@ def set_seed(seed):
3840

3941
def get_function(name):
4042
mapping = {}
41-
43+
4244
mapped = mapping.get(name)
4345
if not mapped:
4446
raise Exception('No keras function found for "{}"'.format(name))
45-
47+
4648
return mapped
4749

4850

@@ -55,7 +57,7 @@ def build_optimizer(type, lr, kerasDefaults):
5557
nesterov=kerasDefaults['nesterov_sgd'])#,
5658
#clipnorm=kerasDefaults['clipnorm'],
5759
#clipvalue=kerasDefaults['clipvalue'])
58-
60+
5961
elif type == 'rmsprop':
6062
return optimizers.RMSprop(lr=lr, rho=kerasDefaults['rho'],
6163
epsilon=kerasDefaults['epsilon'],
@@ -101,10 +103,10 @@ def build_optimizer(type, lr, kerasDefaults):
101103

102104

103105
def build_initializer(type, kerasDefaults, seed=None, constant=0.):
104-
106+
105107
if type == 'constant':
106108
return initializers.Constant(value=constant)
107-
109+
108110
elif type == 'uniform':
109111
return initializers.RandomUniform(minval=kerasDefaults['minval_uniform'],
110112
maxval=kerasDefaults['maxval_uniform'],
@@ -155,3 +157,18 @@ def evaluate_autoencoder(y_pred, y_test):
155157
# print('Mean squared error: {}%'.format(mse))
156158
return {'mse': mse, 'r2_score': r2, 'correlation': corr}
157159

160+
161+
class PermanentDropout(Dropout):
162+
def __init__(self, rate, **kwargs):
163+
super(PermanentDropout, self).__init__(rate, **kwargs)
164+
self.uses_learning_phase = False
165+
166+
def call(self, x, mask=None):
167+
if 0. < self.rate < 1.:
168+
noise_shape = self._get_noise_shape(x)
169+
x = K.dropout(x, self.rate, noise_shape)
170+
return x
171+
172+
173+
def register_permanent_dropout():
174+
get_custom_objects()['PermanentDropout'] = PermanentDropout

0 commit comments

Comments
 (0)