
Commit 26108a7

Merge pull request #3698 from guaneec/hn-activation
Remove activation from final layer of Hypernetworks
2 parents 2cf3d2a + 4918eb6 commit 26108a7

File tree

2 files changed: +26 -14 lines changed


modules/hypernetworks/hypernetwork.py

Lines changed: 24 additions & 12 deletions
@@ -35,7 +35,8 @@ class HypernetworkModule(torch.nn.Module):
     }
     activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -48,8 +49,8 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
             # Add a fully-connected layer
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
 
-            # Add an activation func
-            if activation_func == "linear" or activation_func is None:
+            # Add an activation func except last layer
+            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
                 pass
             elif activation_func in self.activation_dict:
                 linears.append(self.activation_dict[activation_func]())
@@ -60,8 +61,8 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
-            # Add dropout expect last layer
-            if use_dropout and i < len(layer_structure) - 3:
+            # Add dropout except last layer
+            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
                 linears.append(torch.nn.Dropout(p=0.3))
 
         self.linear = torch.nn.Sequential(*linears)
@@ -75,7 +76,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
                 w, b = layer.weight.data, layer.bias.data
                 if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
                     normal_(w, mean=0.0, std=0.01)
-                    normal_(b, mean=0.0, std=0.005)
+                    normal_(b, mean=0.0, std=0)
                 elif weight_init == 'XavierUniform':
                     xavier_uniform_(w)
                     zeros_(b)
@@ -127,7 +128,7 @@ class Hypernetwork:
    filename = None
    name = None

-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
        self.filename = None
        self.name = name
        self.layers = {}
@@ -139,11 +140,15 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
        self.weight_init = weight_init
        self.add_layer_norm = add_layer_norm
        self.use_dropout = use_dropout
+        self.activate_output = activate_output
+        self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True

        for size in enable_sizes or []:
            self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
            )

    def weights(self):
@@ -171,7 +176,9 @@ def save(self, filename):
        state_dict['use_dropout'] = self.use_dropout
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
-
+        state_dict['activate_output'] = self.activate_output
+        state_dict['last_layer_dropout'] = self.last_layer_dropout
+
        torch.save(state_dict, filename)

    def load(self, filename):
@@ -191,12 +198,17 @@ def load(self, filename):
        print(f"Layer norm is set to {self.add_layer_norm}")
        self.use_dropout = state_dict.get('use_dropout', False)
        print(f"Dropout usage is set to {self.use_dropout}" )
+        self.activate_output = state_dict.get('activate_output', True)
+        print(f"Activate last layer is set to {self.activate_output}")
+        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)

        for size, sd in state_dict.items():
            if type(size) == int:
                self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
                )

        self.name = state_dict.get('name', self.name)
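
To make the new flags concrete, here is a minimal, self-contained sketch of the patched construction loop from HypernetworkModule.__init__. The build_layers helper, the dim value of 768 and the two-entry activation table are illustrative only; the activation and dropout conditions are taken from the hunks above.

# Minimal sketch of the patched layer-construction logic (simplified from
# HypernetworkModule.__init__; names and default values here are illustrative).
import torch

def build_layers(dim, layer_structure, activation_func="swish",
                 use_dropout=True, activate_output=False, last_layer_dropout=True):
    activation_dict = {"swish": torch.nn.Hardswish, "relu": torch.nn.ReLU}
    linears = []
    for i in range(len(layer_structure) - 1):
        # Fully-connected layer for this step of the structure.
        linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))

        # New behaviour: skip the activation on the final layer unless activate_output is set.
        if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
            pass
        elif activation_func in activation_dict:
            linears.append(activation_dict[activation_func]())

        # Dropout never lands on the output layer itself; last_layer_dropout decides
        # whether the block just before the output still gets it.
        if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
            linears.append(torch.nn.Dropout(p=0.3))

    return torch.nn.Sequential(*linears)

print(build_layers(768, [1, 2, 1]))
# Linear(768->1536), Hardswish, Dropout, Linear(1536->768) -- no trailing activation

With the default activate_output=False the final Linear is emitted bare, which is the point of this PR; last_layer_dropout only controls dropout on the layer feeding the output. Note that load() falls back to activate_output=True when the key is missing, presumably so hypernetworks saved before this change keep the behaviour they were trained with.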

modules/ui.py

Lines changed: 2 additions & 2 deletions
@@ -1182,8 +1182,8 @@ def create_ui(wrap_gradio_gpu_call):
                new_hypernetwork_name = gr.Textbox(label="Name")
                new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
                new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
-                new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
-                new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
+                new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
+                new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
                new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
                new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
                overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
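
For context on the first dropdown's choices= argument: the list comes from modules.hypernetworks.ui.keys, which the webui derives from HypernetworkModule.activation_dict (the inspect.getmembers context line in the first hunk of hypernetwork.py above). The sketch below shows one way that derivation can work; the exact definition of keys in modules/hypernetworks/ui.py is an assumption here, not part of this commit.

# Hedged sketch: deriving the dropdown's choice list from an activation table
# (the real modules.hypernetworks.ui.keys may be defined differently).
import inspect
import torch

activation_dict = {
    "linear": torch.nn.Identity,
    "relu": torch.nn.ReLU,
    "swish": torch.nn.Hardswish,
}
# Same trick as the context line in the first hunk: pull in every activation class
# that torch.nn.modules.activation defines, keyed by its lower-cased class name.
activation_dict.update({
    cls_name.lower(): cls_obj
    for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation)
    if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'
})

keys = list(activation_dict.keys())  # what gr.Dropdown(choices=...) would receive
print(keys[:5])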

0 commit comments