@@ -35,7 +35,8 @@ class HypernetworkModule(torch.nn.Module):
     }
     activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -48,8 +49,8 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
             # Add a fully-connected layer
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
 
-            # Add an activation func
-            if activation_func == "linear" or activation_func is None:
+            # Add an activation func except last layer
+            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
                 pass
             elif activation_func in self.activation_dict:
                 linears.append(self.activation_dict[activation_func]())
@@ -60,8 +61,8 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
-            # Add dropout expect last layer
-            if use_dropout and i < len(layer_structure) - 3:
+            # Add dropout except last layer
+            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
                 linears.append(torch.nn.Dropout(p=0.3))
 
         self.linear = torch.nn.Sequential(*linears)
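A minimal illustrative sketch (not part of the patch) of what the two new flags do to the layer stack built above: `activate_output` decides whether the final `Linear` is followed by an activation, and `last_layer_dropout` decides whether dropout is still inserted in front of it. The helper name `describe_layers` and the 1-2-1 layer structure are hypothetical, and the sketch assumes a real activation function was requested (not `"linear"`/`None`).

```python
# Illustrative only: mirrors the conditions from the patched __init__ to show
# which linears receive an activation / dropout for given flag settings.
def describe_layers(layer_structure, activate_output=False, last_layer_dropout=True):
    n = len(layer_structure)
    for i in range(n - 1):
        # activation is skipped on the last linear unless activate_output is set
        has_activation = not (i >= n - 2 and not activate_output)
        # dropout is skipped on the last linear; last_layer_dropout controls the one before it
        has_dropout = i < n - 3 or (last_layer_dropout and i < n - 2)
        print(f"linear {i}: activation={has_activation}, dropout={has_dropout}")

# With the creation-time defaults (activate_output=False, last_layer_dropout=True)
# and a 1-2-1 structure, only the first linear gets an activation and dropout:
describe_layers([1, 2, 1])
```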
@@ -75,7 +76,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
                     w, b = layer.weight.data, layer.bias.data
                     if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
                         normal_(w, mean=0.0, std=0.01)
-                        normal_(b, mean=0.0, std=0.005)
+                        normal_(b, mean=0.0, std=0)
                     elif weight_init == 'XavierUniform':
                         xavier_uniform_(w)
                         zeros_(b)
@@ -127,7 +128,7 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -139,11 +140,15 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
         self.weight_init = weight_init
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
+        self.activate_output = activate_output
+        self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
             )
 
     def weights(self):
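A hypothetical usage sketch (not taken from the PR) of how `last_layer_dropout` reaches `Hypernetwork.__init__` through `**kwargs`, so existing callers that never pass it keep the previous behaviour while new callers can turn it off. All argument values are made up for illustration and assume `Hypernetwork` is imported from the module this diff patches.

```python
# Hypothetical example values; only the keyword names come from the patched signature.
hn_default = Hypernetwork(name="example", enable_sizes=[320],
                          layer_structure=[1, 2, 1], activation_func="relu",
                          weight_init="Normal", use_dropout=True)
# last_layer_dropout was not passed, so the kwargs lookup falls back to True.

hn_no_last = Hypernetwork(name="example", enable_sizes=[320],
                          layer_structure=[1, 2, 1], activation_func="relu",
                          weight_init="Normal", use_dropout=True,
                          last_layer_dropout=False)
# Explicitly disables the Dropout that would otherwise precede the final Linear.
```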
@@ -171,7 +176,9 @@ def save(self, filename):
         state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
-
+        state_dict['activate_output'] = self.activate_output
+        state_dict['last_layer_dropout'] = self.last_layer_dropout
+
         torch.save(state_dict, filename)
 
     def load(self, filename):
@@ -191,12 +198,17 @@ def load(self, filename):
         print(f"Layer norm is set to {self.add_layer_norm}")
         self.use_dropout = state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}")
+        self.activate_output = state_dict.get('activate_output', True)
+        print(f"Activate last layer is set to {self.activate_output}")
+        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
                 )
 
         self.name = state_dict.get('name', self.name)
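A short sketch of why the fallback defaults in `load()` differ from the constructor defaults: checkpoints saved before this change always activated the last layer and never applied dropout in front of it, so missing keys resolve to that older behaviour rather than to the new constructor defaults (`activate_output=False`, `last_layer_dropout=True`). The empty dict below stands in for such an older checkpoint and is not code from the PR.

```python
# Stand-in for a state_dict saved before this change (new keys absent).
old_checkpoint = {}

# Mirrors the .get() fallbacks added in load():
activate_output = old_checkpoint.get('activate_output', True)         # old nets activated the final layer
last_layer_dropout = old_checkpoint.get('last_layer_dropout', False)  # old nets had no dropout before it

assert activate_output is True
assert last_layer_dropout is False
```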