@@ -34,7 +34,8 @@ class HypernetworkModule(torch.nn.Module):
     }
     activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+                 add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -60,7 +61,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
             # Add dropout except last layer
-            if use_dropout and i < len(layer_structure) - 3:
+            if 'last_layer_dropout' in kwargs and kwargs['last_layer_dropout'] and use_dropout and i < len(layer_structure) - 2:
                 linears.append(torch.nn.Dropout(p=0.3))
 
         self.linear = torch.nn.Sequential(*linears)
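
Note on the changed condition: the old test, i < len(layer_structure) - 3, never fires for a three-entry structure such as [1, 2, 1], while the new test adds Dropout after every hidden block except the one feeding the output layer, and only when the caller passes last_layer_dropout=True. A minimal standalone sketch of the index arithmetic (not part of the patch; it assumes use_dropout is True and folds the kwargs lookup into a plain boolean):

    # Which loop indices i receive a torch.nn.Dropout under each condition.
    def dropout_indices(layer_structure, last_layer_dropout=True):
        old = [i for i in range(len(layer_structure) - 1)
               if i < len(layer_structure) - 3]
        new = [i for i in range(len(layer_structure) - 1)
               if last_layer_dropout and i < len(layer_structure) - 2]
        return old, new

    print(dropout_indices([1, 2, 1]))     # ([], [0])    -- the old code added no dropout here
    print(dropout_indices([1, 3, 3, 1]))  # ([0], [0, 1])
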
@@ -126,7 +127,7 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -139,11 +140,14 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
         self.activate_output = activate_output
+        self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
             )
 
     def weights(self):
@@ -172,7 +176,8 @@ def save(self, filename):
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
         state_dict['activate_output'] = self.activate_output
-
+        state_dict['last_layer_dropout'] = self.last_layer_dropout
+
         torch.save(state_dict, filename)
 
     def load(self, filename):
@@ -193,12 +198,16 @@ def load(self, filename):
         self.use_dropout = state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}")
         self.activate_output = state_dict.get('activate_output', True)
+        print(f"Activate last layer is set to {self.activate_output}")
+        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
                 )
 
         self.name = state_dict.get('name', self.name)
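
One behavioral detail of this patch: the Hypernetwork constructor defaults last_layer_dropout to True for newly created networks, while load() falls back to False for checkpoint files that predate the key, so old files load with the flag off and fresh networks get it on. A small standalone sketch of the two defaults (dict.get with a default is equivalent to the conditional expression the constructor uses):

    kwargs = {}                                                  # constructor called without the flag
    ctor_default = kwargs.get('last_layer_dropout', True)        # -> True for new networks

    state_dict = {}                                              # checkpoint saved before this patch
    load_default = state_dict.get('last_layer_dropout', False)   # -> False for old files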