Commit de096d0

aria1th authored and AUTOMATIC1111 committed
Weight initialization and More activation func
add weight init
add weight init option in create_hypernetwork
fstringify hypernet info
save weight initialization info for further debugging
fill bias with zero for He/Xavier
initialize LayerNorm with Normal
fix loading weight_init
1 parent 3e15f8e commit de096d0

File tree

3 files changed: 44 additions & 11 deletions


modules/hypernetworks/hypernetwork.py

Lines changed: 38 additions & 9 deletions
@@ -5,6 +5,7 @@
 import os
 import sys
 import traceback
+import inspect
 
 import modules.textual_inversion.dataset
 import torch
@@ -15,20 +16,25 @@
 from modules.textual_inversion import textual_inversion
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum
+from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
 
 from collections import defaultdict, deque
 from statistics import stdev, mean
 
+
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
     activation_dict = {
         "relu": torch.nn.ReLU,
         "leakyrelu": torch.nn.LeakyReLU,
         "elu": torch.nn.ELU,
         "swish": torch.nn.Hardswish,
+        "tanh": torch.nn.Tanh,
+        "sigmoid": torch.nn.Sigmoid,
     }
+    activation_dict.update({cls_name: cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
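
The `activation_dict.update(...)` line is the interesting part of this hunk: it harvests every activation class defined in `torch.nn.modules.activation`, so the dropdown picks up options like GELU without hand-registering each one. A minimal standalone sketch of that comprehension (the `harvested` name is ours, not the commit's):

```python
# Sketch of the harvesting comprehension from the hunk above. The
# __module__ check keeps only classes defined in torch.nn.modules.activation,
# filtering out anything the module merely re-imports.
import inspect
import torch.nn.modules.activation

harvested = {
    cls_name: cls_obj
    for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation)
    if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'
}
print(sorted(harvested))  # class names such as 'ELU', 'GELU', 'Hardswish', 'Tanh', ...
```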
@@ -65,9 +71,24 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
         else:
             for layer in self.linear:
                 if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
-                    layer.weight.data.normal_(mean=0.0, std=0.01)
-                    layer.bias.data.zero_()
-
+                    w, b = layer.weight.data, layer.bias.data
+                    if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
+                        normal_(w, mean=0.0, std=0.01)
+                        normal_(b, mean=0.0, std=0.005)
+                    elif weight_init == 'XavierUniform':
+                        xavier_uniform_(w)
+                        zeros_(b)
+                    elif weight_init == 'XavierNormal':
+                        xavier_normal_(w)
+                        zeros_(b)
+                    elif weight_init == 'KaimingUniform':
+                        kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+                        zeros_(b)
+                    elif weight_init == 'KaimingNormal':
+                        kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+                        zeros_(b)
+                    else:
+                        raise KeyError(f"Key {weight_init} is not defined as initialization!")
         self.to(devices.device)
 
     def fix_old_state_dict(self, state_dict):
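
To see what the new dispatch does in isolation, here is the same branch logic applied to a toy `torch.nn.Linear`; the layer shape and the selected option are made up for illustration, while the initializer calls mirror the commit:

```python
import torch
from torch.nn.init import normal_, xavier_normal_, kaiming_normal_, zeros_

layer = torch.nn.Linear(320, 640)   # toy layer, not taken from the webui
w, b = layer.weight.data, layer.bias.data

weight_init = "KaimingNormal"       # one of the five dropdown options
activation_func = "relu"

if weight_init == "Normal":         # the old hard-coded behaviour
    normal_(w, mean=0.0, std=0.01)
    normal_(b, mean=0.0, std=0.005)
elif weight_init == "XavierNormal":   # "fill bias with zero for He/Xavier"
    xavier_normal_(w)
    zeros_(b)
elif weight_init == "KaimingNormal":  # He init; gain follows the activation
    kaiming_normal_(w, nonlinearity='leaky_relu' if activation_func == 'leakyrelu' else 'relu')
    zeros_(b)
```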
@@ -105,7 +126,7 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -114,13 +135,14 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
         self.sd_checkpoint_name = None
         self.layer_structure = layer_structure
         self.activation_func = activation_func
+        self.weight_init = weight_init
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
             )
 
     def weights(self):
@@ -144,6 +166,7 @@ def save(self, filename):
         state_dict['layer_structure'] = self.layer_structure
         state_dict['activation_func'] = self.activation_func
         state_dict['is_layer_norm'] = self.add_layer_norm
+        state_dict['weight_initialization'] = self.weight_init
         state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -158,15 +181,21 @@ def load(self, filename):
         state_dict = torch.load(filename, map_location='cpu')
 
         self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+        print(self.layer_structure)
         self.activation_func = state_dict.get('activation_func', None)
+        print(f"Activation function is {self.activation_func}")
+        self.weight_init = state_dict.get('weight_initialization', 'Normal')
+        print(f"Weight initialization is {self.weight_init}")
         self.add_layer_norm = state_dict.get('is_layer_norm', False)
+        print(f"Layer norm is set to {self.add_layer_norm}")
         self.use_dropout = state_dict.get('use_dropout', False)
+        print(f"Dropout usage is set to {self.use_dropout}" )
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
                 )
 
         self.name = state_dict.get('name', self.name)
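
Note the `.get('weight_initialization', 'Normal')` fallback in the load hunk: hypernetworks saved before this commit carry no such key, so they load with 'Normal', matching the previous hard-coded behaviour. A quick sketch of that path (the filename is illustrative):

```python
import torch

# 'old_hypernet.pt' stands in for any hypernetwork saved before this commit.
state_dict = torch.load("old_hypernet.pt", map_location='cpu')
weight_init = state_dict.get('weight_initialization', 'Normal')
print(f"Weight initialization is {weight_init}")  # prints 'Normal' for old files
```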

modules/hypernetworks/ui.py

Lines changed: 3 additions & 1 deletion
@@ -8,8 +8,9 @@
 from modules import devices, sd_hijack, shared
 from modules.hypernetworks import hypernetwork
 
+keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
 
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
     # Remove illegal characters from name.
     name = "".join( x for x in name if (x.isalnum() or x in "._- "))
 
@@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
         enable_sizes=[int(x) for x in enable_sizes],
         layer_structure=layer_structure,
         activation_func=activation_func,
+        weight_init=weight_init,
         add_layer_norm=add_layer_norm,
         use_dropout=use_dropout,
     )
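
The new module-level `keys` list is what modules/ui.py below feeds to the activation dropdown. A sketch of what it holds, runnable only inside the webui environment (the slicing and print are ours):

```python
from modules.hypernetworks import hypernetwork

# Dict insertion order puts the hand-registered names first, followed by
# everything harvested from torch.nn.modules.activation.
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
print(keys[:6])  # ['relu', 'leakyrelu', 'elu', 'swish', 'tanh', 'sigmoid']
```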

modules/ui.py

Lines changed: 3 additions & 1 deletion
@@ -1238,7 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
         new_hypernetwork_name = gr.Textbox(label="Name")
         new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
         new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
-        new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
+        new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
+        new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
         new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
         new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
         overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
@@ -1342,6 +1343,7 @@ def create_ui(wrap_gradio_gpu_call):
             overwrite_old_hypernetwork,
             new_hypernetwork_layer_structure,
             new_hypernetwork_activation_func,
+            new_hypernetwork_initialization_option,
             new_hypernetwork_add_layer_norm,
             new_hypernetwork_use_dropout
         ],
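
The initialization dropdown's label encodes a standard rule of thumb: Kaiming (He) initialization suits rectifier-family activations, Xavier suits saturating, sigmoid-like ones. Spelled out as an illustrative mapping (ours, not code from the commit; the webui leaves the choice to the user):

```python
# Illustrative pairing of activation to the suggested initializer option.
RECOMMENDED_INIT = {
    "relu": "KaimingNormal",
    "leakyrelu": "KaimingNormal",
    "elu": "KaimingNormal",
    "swish": "KaimingNormal",
    "tanh": "XavierNormal",
    "sigmoid": "XavierNormal",
}
```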
