@@ -22,6 +22,8 @@
 from statistics import stdev, mean
 
 
+optimizer_dict = {optim_name: cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
     activation_dict = {
@@ -142,6 +144,8 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
         self.use_dropout = use_dropout
         self.activate_output = activate_output
         self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+        self.optimizer_name = None
+        self.optimizer_state_dict = None
 
         for size in enable_sizes or []:
             self.layers[size] = (
@@ -163,6 +167,7 @@ def weights(self):
 
     def save(self, filename):
         state_dict = {}
+        optimizer_saved_dict = {}
 
         for k, v in self.layers.items():
             state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -178,8 +183,15 @@ def save(self, filename):
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
         state_dict['activate_output'] = self.activate_output
         state_dict['last_layer_dropout'] = self.last_layer_dropout
-
+
+        if self.optimizer_name is not None:
+            optimizer_saved_dict['optimizer_name'] = self.optimizer_name
+
         torch.save(state_dict, filename)
+        if shared.opts.save_optimizer_state and self.optimizer_state_dict:
+            optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+            torch.save(optimizer_saved_dict, filename + '.optim')
 
     def load(self, filename):
         self.filename = filename
@@ -202,6 +214,18 @@ def load(self, filename):
         print(f"Activate last layer is set to {self.activate_output}")
         self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
 
+        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
+        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
+        print(f"Optimizer name is {self.optimizer_name}")
+        if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+        else:
+            self.optimizer_state_dict = None
+        if self.optimizer_state_dict:
+            print("Loaded existing optimizer from checkpoint")
+        else:
+            print("No saved optimizer exists in checkpoint")
+
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
@@ -219,11 +243,11 @@ def load(self, filename):
 
 def list_hypernetworks(path):
     res = {}
-    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
         name = os.path.splitext(os.path.basename(filename))[0]
         # Prevent a hypothetical "None.pt" from being listed.
         if name != "None":
-            res[name] = filename
+            res[name + f"({sd_models.model_hash(filename)})"] = filename
     return res
 
 
@@ -358,6 +382,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
 
+    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
 
     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
@@ -404,8 +429,22 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
-    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
-    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+
+    # Here we use optimizer from saved HN, or we can specify as UI option.
+    if hypernetwork.optimizer_name in optimizer_dict:
+        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+        optimizer_name = hypernetwork.optimizer_name
+    else:
+        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
+        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+        optimizer_name = 'AdamW'
+
+    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
+        try:
+            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+        except RuntimeError as e:
+            print("Cannot resume from saved optimizer!")
+            print(e)
 
     steps_without_grad = 0
 
@@ -467,7 +506,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             # Before saving, change name to match current checkpoint.
             hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
             last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+            hypernetwork.optimizer_name = optimizer_name
+            if shared.opts.save_optimizer_state:
+                hypernetwork.optimizer_state_dict = optimizer.state_dict()
             save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
+            hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
 
         textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
             "loss": f"{previous_mean_loss:.7f}",
@@ -530,8 +573,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     report_statistics(loss_dict)
 
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+    hypernetwork.optimizer_name = optimizer_name
+    if shared.opts.save_optimizer_state:
+        hypernetwork.optimizer_state_dict = optimizer.state_dict()
     save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-
+    del optimizer
+    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
     return hypernetwork, filename
 
 def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
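
For readers who want the mechanism outside the webui codebase, here is a minimal, self-contained sketch of the optimizer round-trip this diff implements: an optimizer lookup built from torch.optim with inspect.getmembers, the optimizer state written to a sibling `.optim` file keyed to a checkpoint hash, and a guarded restore on load. The helper names (save_with_optimizer, load_optimizer) and the file_hash argument are illustrative stand-ins, not part of the repository; the actual code uses sd_models.model_hash and shared.opts.save_optimizer_state as shown in the diff above.

```python
import inspect
import os

import torch

# Same idea as the diff's optimizer_dict: every optimizer class torch.optim exposes,
# minus the abstract Optimizer base class.
optimizer_dict = {
    name: cls for name, cls in inspect.getmembers(torch.optim, inspect.isclass)
    if name != "Optimizer"
}


def save_with_optimizer(module, optimizer, optimizer_name, filename, file_hash):
    """Write model weights to `filename` and optimizer state to `filename + '.optim'`."""
    torch.save(module.state_dict(), filename)
    torch.save({
        'optimizer_name': optimizer_name,
        'hash': file_hash,  # ties the .optim file to this specific checkpoint
        'optimizer_state_dict': optimizer.state_dict(),
    }, filename + '.optim')


def load_optimizer(params, filename, file_hash, lr=1e-3):
    """Rebuild the optimizer by name and restore its state only if the hash matches."""
    optim_path = filename + '.optim'
    saved = torch.load(optim_path, map_location='cpu') if os.path.exists(optim_path) else {}

    name = saved.get('optimizer_name', 'AdamW')
    optimizer_cls = optimizer_dict.get(name, torch.optim.AdamW)
    optimizer = optimizer_cls(params=params, lr=lr)

    if saved.get('hash', None) == file_hash and saved.get('optimizer_state_dict'):
        try:
            optimizer.load_state_dict(saved['optimizer_state_dict'])
        except RuntimeError as e:
            # Mirrors the diff's fallback: a mismatched state simply leaves a fresh optimizer.
            print("Cannot resume from saved optimizer!", e)
    return optimizer
```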