5
5
from keras import optimizers
6
6
from keras import initializers
7
7
8
+ from keras .layers import Dropout
9
+ from keras .utils import get_custom_objects
8
10
from keras .metrics import binary_crossentropy , mean_squared_error
9
11
10
12
from scipy .stats .stats import pearsonr
@@ -38,11 +40,11 @@ def set_seed(seed):
38
40
39
41
def get_function (name ):
40
42
mapping = {}
41
-
43
+
42
44
mapped = mapping .get (name )
43
45
if not mapped :
44
46
raise Exception ('No keras function found for "{}"' .format (name ))
45
-
47
+
46
48
return mapped
47
49
48
50
@@ -55,7 +57,7 @@ def build_optimizer(type, lr, kerasDefaults):
55
57
nesterov = kerasDefaults ['nesterov_sgd' ])#,
56
58
#clipnorm=kerasDefaults['clipnorm'],
57
59
#clipvalue=kerasDefaults['clipvalue'])
58
-
60
+
59
61
elif type == 'rmsprop' :
60
62
return optimizers .RMSprop (lr = lr , rho = kerasDefaults ['rho' ],
61
63
epsilon = kerasDefaults ['epsilon' ],
@@ -101,10 +103,10 @@ def build_optimizer(type, lr, kerasDefaults):
101
103
102
104
103
105
def build_initializer (type , kerasDefaults , seed = None , constant = 0. ):
104
-
106
+
105
107
if type == 'constant' :
106
108
return initializers .Constant (value = constant )
107
-
109
+
108
110
elif type == 'uniform' :
109
111
return initializers .RandomUniform (minval = kerasDefaults ['minval_uniform' ],
110
112
maxval = kerasDefaults ['maxval_uniform' ],
@@ -155,3 +157,18 @@ def evaluate_autoencoder(y_pred, y_test):
155
157
# print('Mean squared error: {}%'.format(mse))
156
158
return {'mse' : mse , 'r2_score' : r2 , 'correlation' : corr }
157
159
160
+
161
+ class PermanentDropout (Dropout ):
162
+ def __init__ (self , rate , ** kwargs ):
163
+ super (PermanentDropout , self ).__init__ (rate , ** kwargs )
164
+ self .uses_learning_phase = False
165
+
166
+ def call (self , x , mask = None ):
167
+ if 0. < self .rate < 1. :
168
+ noise_shape = self ._get_noise_shape (x )
169
+ x = K .dropout (x , self .rate , noise_shape )
170
+ return x
171
+
172
+
173
+ def register_permanent_dropout ():
174
+ get_custom_objects ()['PermanentDropout' ] = PermanentDropout
0 commit comments