@@ -5,6 +5,7 @@
 import os
 import sys
 import traceback
+import inspect
 
 import modules.textual_inversion.dataset
 import torch
@@ -15,20 +16,25 @@
 from modules.textual_inversion import textual_inversion
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum
+from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
 
 from collections import defaultdict, deque
 from statistics import stdev, mean
 
+
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
     activation_dict = {
         "relu": torch.nn.ReLU,
         "leakyrelu": torch.nn.LeakyReLU,
         "elu": torch.nn.ELU,
         "swish": torch.nn.Hardswish,
+        "tanh": torch.nn.Tanh,
+        "sigmoid": torch.nn.Sigmoid,
     }
+    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
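A note on the activation_dict.update(...) line in the hunk above: rather than hand-listing every activation, it harvests each class defined in torch.nn.modules.activation and keys it by lowercased class name. A minimal standalone sketch of the same trick; which names appear depends on the installed PyTorch version:

import inspect
import torch

# Collect every activation class the module defines (ReLU, GELU, SiLU, ...),
# skipping names merely re-exported from elsewhere via the __module__ check.
activations = {
    name.lower(): obj
    for name, obj in inspect.getmembers(torch.nn.modules.activation)
    if inspect.isclass(obj) and obj.__module__ == 'torch.nn.modules.activation'
}

print(sorted(activations))             # e.g. ['celu', 'elu', 'gelu', ...]
act = activations['gelu']()            # instantiate from a UI-style string
print(act(torch.tensor([-1.0, 1.0])))  # tensor([-0.1587, 0.8413])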
@@ -65,9 +71,24 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
         else:
             for layer in self.linear:
                 if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
-                    layer.weight.data.normal_(mean=0.0, std=0.01)
-                    layer.bias.data.zero_()
-
+                    w, b = layer.weight.data, layer.bias.data
+                    if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
+                        normal_(w, mean=0.0, std=0.01)
+                        normal_(b, mean=0.0, std=0.005)
+                    elif weight_init == 'XavierUniform':
+                        xavier_uniform_(w)
+                        zeros_(b)
+                    elif weight_init == 'XavierNormal':
+                        xavier_normal_(w)
+                        zeros_(b)
+                    elif weight_init == 'KaimingUniform':
+                        kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+                        zeros_(b)
+                    elif weight_init == 'KaimingNormal':
+                        kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+                        zeros_(b)
+                    else:
+                        raise KeyError(f"Key {weight_init} is not defined as initialization!")
         self.to(devices.device)
 
     def fix_old_state_dict(self, state_dict):
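For readers skimming the initialization hunk: every branch is a thin dispatch over torch.nn.init, and the Kaiming variants are told which nonlinearity follows so they compute the matching gain. A condensed, self-contained sketch of the same dispatch (three of the five branches; init_layer is an illustrative helper, not a name from the codebase):

import torch
from torch.nn.init import normal_, xavier_uniform_, kaiming_normal_, zeros_

def init_layer(layer, weight_init='Normal', activation_func=None):
    w, b = layer.weight.data, layer.bias.data
    if weight_init == 'Normal' or isinstance(layer, torch.nn.LayerNorm):
        normal_(w, mean=0.0, std=0.01)   # small random weights, as before
        normal_(b, mean=0.0, std=0.005)
    elif weight_init == 'XavierUniform':
        xavier_uniform_(w)               # variance scaled by fan_in + fan_out
        zeros_(b)
    elif weight_init == 'KaimingNormal':
        # gain depends on the activation that follows the layer
        kaiming_normal_(w, nonlinearity='leaky_relu' if activation_func == 'leakyrelu' else 'relu')
        zeros_(b)
    else:
        raise KeyError(f"Key {weight_init} is not defined as initialization!")

layer = torch.nn.Linear(320, 640)
init_layer(layer, 'KaimingNormal')
print(layer.weight.std())                # roughly sqrt(2 / 320) ~ 0.079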
@@ -105,7 +126,7 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -114,14 +135,15 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
         self.sd_checkpoint_name = None
         self.layer_structure = layer_structure
         self.activation_func = activation_func
+        self.weight_init = weight_init
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
         self.activate_output = activate_output
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
             )
 
     def weights(self):
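To show where the new weight_init argument sits in the chain, a hypothetical construction call (the values are illustrative; the sizes are the usual SD1.x cross-attention context widths, and each size gets a pair of modules, one applied on the key path and one on the value path):

hn = Hypernetwork(
    name="example",
    enable_sizes=[320, 640, 768, 1280],   # one module pair per context width
    layer_structure=[1, 2, 1],            # hidden widths as multiples of dim
    activation_func="relu",
    weight_init="XavierNormal",           # forwarded to every HypernetworkModule
)
k_module, v_module = hn.layers[768]       # the pair built for 768-dim contexts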
@@ -145,6 +167,7 @@ def save(self, filename):
         state_dict['layer_structure'] = self.layer_structure
         state_dict['activation_func'] = self.activation_func
         state_dict['is_layer_norm'] = self.add_layer_norm
+        state_dict['weight_initialization'] = self.weight_init
         state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
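Worth noting about save(): the file mixes integer keys (the per-size module weights) with string metadata keys, now including 'weight_initialization'; load() later separates them by key type. A toy round-trip with placeholder values:

import torch

state_dict = {
    768: ['key-module weights', 'value-module weights'],  # placeholder pair
    'layer_structure': [1, 2, 1],
    'activation_func': 'relu',
    'weight_initialization': 'XavierNormal',
    'is_layer_norm': False,
    'use_dropout': False,
}
torch.save(state_dict, 'example_hypernetwork.pt')         # illustrative path
restored = torch.load('example_hypernetwork.pt', map_location='cpu')
print([k for k in restored if type(k) == int])            # [768] -> module sizes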
@@ -160,16 +183,22 @@ def load(self, filename):
         state_dict = torch.load(filename, map_location='cpu')
 
         self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+        print(self.layer_structure)
         self.activation_func = state_dict.get('activation_func', None)
+        print(f"Activation function is {self.activation_func}")
+        self.weight_init = state_dict.get('weight_initialization', 'Normal')
+        print(f"Weight initialization is {self.weight_init}")
         self.add_layer_norm = state_dict.get('is_layer_norm', False)
-        self.use_dropout = state_dict.get('use_dropout', False)
+        print(f"Layer norm is set to {self.add_layer_norm}")
+        self.use_dropout = state_dict.get('use_dropout', False)
+        print(f"Dropout usage is set to {self.use_dropout}")
         self.activate_output = state_dict.get('activate_output', True)
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
                 )
 
         self.name = state_dict.get('name', self.name)
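One consequence of the .get() defaults in load(): checkpoints written before this change have no 'weight_initialization' key, so they fall back to 'Normal' and keep their previous behaviour. In miniature:

# A pre-change checkpoint lacks the new metadata key entirely.
old_state_dict = {'layer_structure': [1, 2, 1], 'activation_func': None}
weight_init = old_state_dict.get('weight_initialization', 'Normal')
print(f"Weight initialization is {weight_init}")  # Weight initialization is Normal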