from collections import defaultdict, deque
from statistics import stdev, mean

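+# Build a lookup of torch.optim optimizer classes by name (e.g. 'AdamW', 'SGD'), skipping the abstract Optimizer base class.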
+optimizer_dict = {optim_name: cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}

class HypernetworkModule(torch.nn.Module):
    multiplier = 1.0
@@ -139,6 +140,8 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
        self.weight_init = weight_init
        self.add_layer_norm = add_layer_norm
        self.use_dropout = use_dropout
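+        # Optimizer name and state live on the hypernetwork itself, so they can be saved and restored with it.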
+        self.optimizer_name = None
+        self.optimizer_state_dict = None

        for size in enable_sizes or []:
            self.layers[size] = (
@@ -171,6 +174,10 @@ def save(self, filename):
        state_dict['use_dropout'] = self.use_dropout
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
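+        # Only write optimizer info when it is present, so checkpoints saved without it keep their old format.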
+        if self.optimizer_name is not None:
+            state_dict['optimizer_name'] = self.optimizer_name
+        if self.optimizer_state_dict:
+            state_dict['optimizer_state_dict'] = self.optimizer_state_dict

        torch.save(state_dict, filename)
@@ -190,7 +197,14 @@ def load(self, filename):
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        print(f"Layer norm is set to {self.add_layer_norm}")
        self.use_dropout = state_dict.get('use_dropout', False)
-        print(f"Dropout usage is set to {self.use_dropout}")
+        print(f"Dropout usage is set to {self.use_dropout}")
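+        # Older checkpoints carry no optimizer info; fall back to AdamW with no saved state.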
+        self.optimizer_name = state_dict.get('optimizer_name', 'AdamW')
+        print(f"Optimizer name is {self.optimizer_name}")
+        self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None)
+        if self.optimizer_state_dict:
+            print("Loaded existing optimizer from checkpoint")
+        else:
+            print("No saved optimizer exists in checkpoint")

        for size, sd in state_dict.items():
            if type(size) == int:
@@ -392,8 +406,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True
-    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
-    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+    # Use the optimizer type stored in the saved hypernetwork; it could also be exposed as a UI option.
+    if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
+        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+    else:
+        print(f"Optimizer type {optimizer_name} is not defined!")
+        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+        optimizer_name = 'AdamW'
+    if hypernetwork.optimizer_state_dict:  # This must change if the optimizer type can differ from the one that produced the saved state.
+        try:
+            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+        except RuntimeError as e:
+            print("Cannot resume from saved optimizer!")
+            print(e)

    steps_without_grad = 0
@@ -455,8 +480,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
            # Before saving, change name to match current checkpoint.
            hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
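+            # Attach the optimizer name, and optionally its full state, so they are written into the .pt file.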
+            hypernetwork.optimizer_name = optimizer_name
+            if shared.opts.save_optimizer_state:
+                hypernetwork.optimizer_state_dict = optimizer.state_dict()
            save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
-
+            hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.

        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
            "loss": f"{previous_mean_loss:.7f}",
            "learn_rate": scheduler.learn_rate
@@ -514,14 +542,18 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
-
    report_statistics(loss_dict)

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+    hypernetwork.optimizer_name = optimizer_name
+    if shared.opts.save_optimizer_state:
+        hypernetwork.optimizer_state_dict = optimizer.state_dict()
    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-
+    del optimizer
+    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
    return hypernetwork, filename

+

def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
    old_hypernetwork_name = hypernetwork.name
    old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None