@@ -28,7 +28,7 @@ class HypernetworkModule(torch.nn.Module):
28
28
"swish" : torch .nn .Hardswish ,
29
29
}
30
30
31
- def __init__ (self , dim , state_dict = None , layer_structure = None , activation_func = None , add_layer_norm = False , use_dropout = False ):
31
+ def __init__ (self , dim , state_dict = None , layer_structure = None , activation_func = None , add_layer_norm = False , use_dropout = False , activate_output = False ):
32
32
super ().__init__ ()
33
33
34
34
assert layer_structure is not None , "layer_structure must not be None"
@@ -42,7 +42,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
42
42
linears .append (torch .nn .Linear (int (dim * layer_structure [i ]), int (dim * layer_structure [i + 1 ])))
43
43
44
44
# Add an activation func after every layer except the last one (the last layer also gets one when activate_output is set)
45
- if activation_func == "linear" or activation_func is None or i >= len (layer_structure ) - 2 :
45
+ if activation_func == "linear" or activation_func is None or ( i >= len (layer_structure ) - 2 and not activate_output ) :
46
46
pass
47
47
elif activation_func in self .activation_dict :
48
48
linears .append (self .activation_dict [activation_func ]())
@@ -105,7 +105,7 @@ class Hypernetwork:
105
105
filename = None
106
106
name = None
107
107
108
- def __init__ (self , name = None , enable_sizes = None , layer_structure = None , activation_func = None , add_layer_norm = False , use_dropout = False ):
108
+ def __init__ (self , name = None , enable_sizes = None , layer_structure = None , activation_func = None , add_layer_norm = False , use_dropout = False , activate_output = False ):
109
109
self .filename = None
110
110
self .name = name
111
111
self .layers = {}
@@ -116,11 +116,12 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
116
116
self .activation_func = activation_func
117
117
self .add_layer_norm = add_layer_norm
118
118
self .use_dropout = use_dropout
119
+ self .activate_output = activate_output
119
120
120
121
for size in enable_sizes or []:
121
122
self .layers [size ] = (
122
- HypernetworkModule (size , None , self .layer_structure , self .activation_func , self .add_layer_norm , self .use_dropout ),
123
- HypernetworkModule (size , None , self .layer_structure , self .activation_func , self .add_layer_norm , self .use_dropout ),
123
+ HypernetworkModule (size , None , self .layer_structure , self .activation_func , self .add_layer_norm , self .use_dropout , self . activate_output ),
124
+ HypernetworkModule (size , None , self .layer_structure , self .activation_func , self .add_layer_norm , self .use_dropout , self . activate_output ),
124
125
)
125
126
126
127
def weights (self ):
@@ -147,6 +148,7 @@ def save(self, filename):
147
148
state_dict ['use_dropout' ] = self .use_dropout
148
149
state_dict ['sd_checkpoint' ] = self .sd_checkpoint
149
150
state_dict ['sd_checkpoint_name' ] = self .sd_checkpoint_name
151
+ state_dict ['activate_output' ] = self .activate_output
150
152
151
153
torch .save (state_dict , filename )
152
154
@@ -161,12 +163,13 @@ def load(self, filename):
161
163
self .activation_func = state_dict .get ('activation_func' , None )
162
164
self .add_layer_norm = state_dict .get ('is_layer_norm' , False )
163
165
self .use_dropout = state_dict .get ('use_dropout' , False )
166
+ self .activate_output = state_dict .get ('activate_output' , True )
164
167
165
168
for size , sd in state_dict .items ():
166
169
if type (size ) == int :
167
170
self .layers [size ] = (
168
- HypernetworkModule (size , sd [0 ], self .layer_structure , self .activation_func , self .add_layer_norm , self .use_dropout ),
169
- HypernetworkModule (size , sd [1 ], self .layer_structure , self .activation_func , self .add_layer_norm , self .use_dropout ),
171
+ HypernetworkModule (size , sd [0 ], self .layer_structure , self .activation_func , self .add_layer_norm , self .use_dropout , self . activate_output ),
172
+ HypernetworkModule (size , sd [1 ], self .layer_structure , self .activation_func , self .add_layer_norm , self .use_dropout , self . activate_output ),
170
173
)
171
174
172
175
self .name = state_dict .get ('name' , self .name )
0 commit comments