Skip to content

Commit 877d94f — "Back compatibility"

Browse files · authored · 1 parent: c702d4d

File tree: 1 file changed (+10 additions, −7 deletions)

modules/hypernetworks/hypernetwork.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ class HypernetworkModule(torch.nn.Module):
2828
"swish": torch.nn.Hardswish,
2929
}
3030

31-
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
31+
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
3232
super().__init__()
3333

3434
assert layer_structure is not None, "layer_structure must not be None"
@@ -42,7 +42,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
4242
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
4343

4444
# Add an activation func except last layer
45-
if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 2:
45+
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
4646
pass
4747
elif activation_func in self.activation_dict:
4848
linears.append(self.activation_dict[activation_func]())
@@ -105,7 +105,7 @@ class Hypernetwork:
105105
filename = None
106106
name = None
107107

108-
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
108+
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
109109
self.filename = None
110110
self.name = name
111111
self.layers = {}
@@ -116,11 +116,12 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
116116
self.activation_func = activation_func
117117
self.add_layer_norm = add_layer_norm
118118
self.use_dropout = use_dropout
119+
self.activate_output = activate_output
119120

120121
for size in enable_sizes or []:
121122
self.layers[size] = (
122-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
123-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
123+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
124+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
124125
)
125126

126127
def weights(self):
@@ -147,6 +148,7 @@ def save(self, filename):
147148
state_dict['use_dropout'] = self.use_dropout
148149
state_dict['sd_checkpoint'] = self.sd_checkpoint
149150
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
151+
state_dict['activate_output'] = self.activate_output
150152

151153
torch.save(state_dict, filename)
152154

@@ -161,12 +163,13 @@ def load(self, filename):
161163
self.activation_func = state_dict.get('activation_func', None)
162164
self.add_layer_norm = state_dict.get('is_layer_norm', False)
163165
self.use_dropout = state_dict.get('use_dropout', False)
166+
self.activate_output = state_dict.get('activate_output', True)
164167

165168
for size, sd in state_dict.items():
166169
if type(size) == int:
167170
self.layers[size] = (
168-
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
169-
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
171+
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
172+
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
170173
)
171174

172175
self.name = state_dict.get('name', self.name)

Comments (0)