Skip to content

Commit 80844ac

Browse files
authored
Merge pull request #1 from aria1th/patch-11
fix dropouts for future hypernetworks
2 parents b6a8bb1 + 029d7c7 commit 80844ac

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

modules/hypernetworks/hypernetwork.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class HypernetworkModule(torch.nn.Module):
3434
}
3535
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
3636

37-
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
37+
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
38+
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
3839
super().__init__()
3940

4041
assert layer_structure is not None, "layer_structure must not be None"
@@ -60,7 +61,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
6061
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
6162

6263
# Add dropout except last layer
63-
if use_dropout and i < len(layer_structure) - 3:
64+
if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
6465
linears.append(torch.nn.Dropout(p=0.3))
6566

6667
self.linear = torch.nn.Sequential(*linears)
@@ -74,7 +75,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
7475
w, b = layer.weight.data, layer.bias.data
7576
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
7677
normal_(w, mean=0.0, std=0.01)
77-
normal_(b, mean=0.0, std=0.005)
78+
normal_(b, mean=0.0, std=0)
7879
elif weight_init == 'XavierUniform':
7980
xavier_uniform_(w)
8081
zeros_(b)
@@ -126,7 +127,7 @@ class Hypernetwork:
126127
filename = None
127128
name = None
128129

129-
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False):
130+
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
130131
self.filename = None
131132
self.name = name
132133
self.layers = {}
@@ -139,11 +140,14 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
139140
self.add_layer_norm = add_layer_norm
140141
self.use_dropout = use_dropout
141142
self.activate_output = activate_output
143+
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
142144

143145
for size in enable_sizes or []:
144146
self.layers[size] = (
145-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
146-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
147+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
148+
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
149+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
150+
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
147151
)
148152

149153
def weights(self):
@@ -172,7 +176,8 @@ def save(self, filename):
172176
state_dict['sd_checkpoint'] = self.sd_checkpoint
173177
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
174178
state_dict['activate_output'] = self.activate_output
175-
179+
state_dict['last_layer_dropout'] = self.last_layer_dropout
180+
176181
torch.save(state_dict, filename)
177182

178183
def load(self, filename):
@@ -193,12 +198,16 @@ def load(self, filename):
193198
self.use_dropout = state_dict.get('use_dropout', False)
194199
print(f"Dropout usage is set to {self.use_dropout}" )
195200
self.activate_output = state_dict.get('activate_output', True)
201+
print(f"Activate last layer is set to {self.activate_output}")
202+
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
196203

197204
for size, sd in state_dict.items():
198205
if type(size) == int:
199206
self.layers[size] = (
200-
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
201-
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
207+
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
208+
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
209+
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
210+
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
202211
)
203212

204213
self.name = state_dict.get('name', self.name)

modules/ui.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,8 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
12381238
new_hypernetwork_name = gr.Textbox(label="Name")
12391239
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
12401240
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
1241-
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
1242-
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
1241+
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
1242+
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Normal is default, for experiments, relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
12431243
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
12441244
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
12451245
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")

0 commit comments

Comments
 (0)