Skip to content

Commit 85fcccc

Browse files
committed
Squashed commit of fixing dropout silently
Fix dropouts for future hypernetworks; add kwargs for Hypernetwork class; hypernet UI for gradio input; add recommended options; remove as options; revert adding options in UI
1 parent b6a8bb1 commit 85fcccc

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

modules/hypernetworks/hypernetwork.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class HypernetworkModule(torch.nn.Module):
3434
}
3535
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
3636

37-
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
37+
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
38+
add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
3839
super().__init__()
3940

4041
assert layer_structure is not None, "layer_structure must not be None"
@@ -60,7 +61,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
6061
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
6162

6263
# Add dropout except last layer
63-
if use_dropout and i < len(layer_structure) - 3:
64+
if 'last_layer_dropout' in kwargs and kwargs['last_layer_dropout'] and use_dropout and i < len(layer_structure) - 2:
6465
linears.append(torch.nn.Dropout(p=0.3))
6566

6667
self.linear = torch.nn.Sequential(*linears)
@@ -126,7 +127,7 @@ class Hypernetwork:
126127
filename = None
127128
name = None
128129

129-
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False):
130+
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
130131
self.filename = None
131132
self.name = name
132133
self.layers = {}
@@ -139,11 +140,14 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
139140
self.add_layer_norm = add_layer_norm
140141
self.use_dropout = use_dropout
141142
self.activate_output = activate_output
143+
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
142144

143145
for size in enable_sizes or []:
144146
self.layers[size] = (
145-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
146-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
147+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
148+
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
149+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
150+
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
147151
)
148152

149153
def weights(self):
@@ -172,7 +176,8 @@ def save(self, filename):
172176
state_dict['sd_checkpoint'] = self.sd_checkpoint
173177
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
174178
state_dict['activate_output'] = self.activate_output
175-
179+
state_dict['last_layer_dropout'] = self.last_layer_dropout
180+
176181
torch.save(state_dict, filename)
177182

178183
def load(self, filename):
@@ -193,12 +198,16 @@ def load(self, filename):
193198
self.use_dropout = state_dict.get('use_dropout', False)
194199
print(f"Dropout usage is set to {self.use_dropout}" )
195200
self.activate_output = state_dict.get('activate_output', True)
201+
print(f"Activate last layer is set to {self.activate_output}")
202+
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
196203

197204
for size, sd in state_dict.items():
198205
if type(size) == int:
199206
self.layers[size] = (
200-
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
201-
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
207+
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
208+
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
209+
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
210+
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
202211
)
203212

204213
self.name = state_dict.get('name', self.name)

modules/ui.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,8 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
12381238
new_hypernetwork_name = gr.Textbox(label="Name")
12391239
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
12401240
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
1241-
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
1242-
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
1241+
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
1242+
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Normal is default, for experiments, relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
12431243
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
12441244
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
12451245
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")

0 commit comments

Comments (0)