@@ -22,6 +22,8 @@
 from statistics import stdev, mean
 
 
+optimizer_dict = {optim_name: cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
     activation_dict = {
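Note on the hunk above: `inspect.getmembers(torch.optim, inspect.isclass)` enumerates every class the module exports, so the comprehension yields a name → class registry with the abstract `Optimizer` base filtered out. A minimal standalone sketch of the same technique, with the comprehension spread out for readability:

```python
# Build a registry of every optimizer class torch.optim exports,
# keyed by class name; the abstract Optimizer base is excluded.
import inspect

import torch

optimizer_dict = {
    optim_name: cls_obj
    for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass)
    if optim_name != "Optimizer"
}

print(sorted(optimizer_dict))  # includes 'Adam', 'AdamW', 'SGD', ...
```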
@@ -142,6 +144,8 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
         self.use_dropout = use_dropout
         self.activate_output = activate_output
         self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+        self.optimizer_name = None
+        self.optimizer_state_dict = None
 
         for size in enable_sizes or []:
             self.layers[size] = (
@@ -163,6 +167,7 @@ def weights(self):
 
     def save(self, filename):
         state_dict = {}
+        optimizer_saved_dict = {}
 
         for k, v in self.layers.items():
             state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -178,8 +183,15 @@ def save(self, filename):
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
         state_dict['activate_output'] = self.activate_output
         state_dict['last_layer_dropout'] = self.last_layer_dropout
-
+
+        if self.optimizer_name is not None:
+            optimizer_saved_dict['optimizer_name'] = self.optimizer_name
+
         torch.save(state_dict, filename)
+        if shared.opts.save_optimizer_state and self.optimizer_state_dict:
+            optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+            torch.save(optimizer_saved_dict, filename + '.optim')
 
     def load(self, filename):
         self.filename = filename
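The save path writes optimizer state to a sidecar file (`filename + '.optim'`) instead of growing the `.pt` itself, and tags it with the checkpoint hash so it can later be matched to the right weights. A self-contained sketch of the sidecar pattern, where `toy_hash` is a stand-in for `sd_models.model_hash(filename)`:

```python
import torch

net = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)
filename = "toy_hypernetwork.pt"
toy_hash = "0123abcd"  # stand-in for sd_models.model_hash(filename)

torch.save(net.state_dict(), filename)   # the network itself
torch.save({                              # the sidecar next to it
    'hash': toy_hash,
    'optimizer_name': 'AdamW',
    'optimizer_state_dict': optimizer.state_dict(),
}, filename + '.optim')
```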
@@ -202,6 +214,18 @@ def load(self, filename):
         print(f"Activate last layer is set to {self.activate_output}")
         self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
 
+        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
+        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
+        print(f"Optimizer name is {self.optimizer_name}")
+        if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+        else:
+            self.optimizer_state_dict = None
+        if self.optimizer_state_dict:
+            print("Loaded existing optimizer from checkpoint")
+        else:
+            print("No saved optimizer exists in checkpoint")
+
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
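On the load side, the sidecar is optional and the hash acts as a guard: state is only kept when the recorded hash matches the checkpoint actually being loaded. A sketch under the same stand-in assumptions as the save example above:

```python
import os

import torch

filename = "toy_hypernetwork.pt"   # written by the save sketch above
toy_hash = "0123abcd"              # hash of the checkpoint being loaded

optim_path = filename + '.optim'
saved = torch.load(optim_path, map_location='cpu') if os.path.exists(optim_path) else {}
optimizer_name = saved.get('optimizer_name', 'AdamW')  # same default as the hunk
state = saved.get('optimizer_state_dict') if saved.get('hash') == toy_hash else None
print("resuming optimizer" if state else "starting optimizer from scratch")
```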
@@ -223,7 +247,7 @@ def list_hypernetworks(path):
         name = os.path.splitext(os.path.basename(filename))[0]
         # Prevent a hypothetical "None.pt" from being listed.
         if name != "None":
-            res[name] = filename
+            res[name + f"({sd_models.model_hash(filename)})"] = filename
     return res
 
 
@@ -358,6 +382,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
 
+    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
 
     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
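These two hunks are coupled: `list_hypernetworks` now decorates display names as `name(hash)`, so `train_hypernetwork` must strip that suffix before building file paths. `rsplit('(', 1)` cuts at the last `(`, which keeps any parentheses belonging to the name itself:

```python
# rsplit('(', 1) splits only at the last '(' so a name like "style(v2)(abcd1234)"
# loses just the appended hash, not its own parentheses.
for display_name in ["anime(deadbeef)", "style(v2)(abcd1234)", "plain"]:
    print(display_name.rsplit('(', 1)[0])
# -> "anime", "style(v2)", "plain"
```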
@@ -404,8 +429,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
-    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
-    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+    # Use the optimizer saved with the hypernetwork; this could also be exposed as a UI option.
+    if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
+        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+    else:
+        print(f"Optimizer type {optimizer_name} is not defined!")
+        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+        optimizer_name = 'AdamW'
+    if hypernetwork.optimizer_state_dict:  # This line must change if the optimizer type can differ from the saved one.
+        try:
+            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+        except RuntimeError as e:
+            print("Cannot resume from saved optimizer!")
+            print(e)
 
     steps_without_grad = 0
 
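The walrus expression binds `optimizer_name` while testing membership in the registry, and resume is best-effort: `load_state_dict` can raise when the saved state does not match the optimizer it is loaded into, hence the `try/except`. A self-contained round-trip sketch, where `saved_state` plays the role of `hypernetwork.optimizer_state_dict`:

```python
import torch

# One "session" trains a step and snapshots the optimizer state...
param = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.AdamW([param], lr=1e-3)
param.grad = torch.ones(4)
opt.step()                                  # populates exp_avg / exp_avg_sq
saved_state = opt.state_dict()              # what the hunks store on the hypernetwork

# ...a fresh session rebuilds the optimizer and resumes, guarded as in the hunk.
fresh_param = torch.nn.Parameter(torch.zeros(4))
fresh_opt = torch.optim.AdamW([fresh_param], lr=1e-3)
try:
    fresh_opt.load_state_dict(saved_state)
    print("Loaded existing optimizer state")
except RuntimeError as e:
    print("Cannot resume from saved optimizer!", e)
```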
@@ -467,7 +503,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                 # Before saving, change name to match current checkpoint.
                 hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
                 last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+                hypernetwork.optimizer_name = optimizer_name
+                if shared.opts.save_optimizer_state:
+                    hypernetwork.optimizer_state_dict = optimizer.state_dict()
                 save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
+                hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
 
             textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
                 "loss": f"{previous_mean_loss:.7f}",
@@ -530,8 +570,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     report_statistics(loss_dict)
 
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+    hypernetwork.optimizer_name = optimizer_name
+    if shared.opts.save_optimizer_state:
+        hypernetwork.optimizer_state_dict = optimizer.state_dict()
     save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-
+    del optimizer
+    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
     return hypernetwork, filename
 
 def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
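Dropping the reference after saving is not cosmetic: Adam-family optimizers keep two moment tensors per trainable parameter, so the serialized state is roughly twice the parameter memory, and a second reference held on the hypernetwork would keep all of it alive. A rough illustration:

```python
import torch

net = torch.nn.Linear(1024, 1024)
opt = torch.optim.AdamW(net.parameters(), lr=1e-3)
net(torch.randn(1, 1024)).sum().backward()
opt.step()  # populates per-parameter exp_avg / exp_avg_sq state

# Optimizer state holds ~2x the parameter element count, which is why the
# hunks null out hypernetwork.optimizer_state_dict once it has been saved.
n_state = sum(t.numel() for s in opt.state.values() for t in s.values() if torch.is_tensor(t))
n_params = sum(p.numel() for p in net.parameters())
print(f"state elements: {n_state}, parameter elements: {n_params}")
```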