Commit 9d96d7d

resolve conflicts
1 parent 20194fd commit 9d96d7d

File tree: 1 file changed, +38 -6 lines

modules/hypernetworks/hypernetwork.py

Lines changed: 38 additions & 6 deletions
@@ -21,6 +21,7 @@
 from collections import defaultdict, deque
 from statistics import stdev, mean
 
+optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
 
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
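
Note: the new module-level optimizer_dict enumerates every optimizer class exported by torch.optim, keyed by class name, skipping only the abstract Optimizer base class (the surrounding file is assumed to already import inspect and torch). A minimal, self-contained sketch of how that lookup behaves; the demo parameter below is purely illustrative:

    # Build the same name -> class mapping as the added line above.
    import inspect
    import torch

    optimizer_dict = {
        optim_name: cls_obj
        for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass)
        if optim_name != "Optimizer"  # exclude the abstract base class
    }
    print(sorted(optimizer_dict))  # e.g. ['ASGD', 'Adadelta', ..., 'AdamW', 'SGD', ...]

    # Any entry can then be instantiated by name, as the training code does later.
    param = torch.nn.Parameter(torch.zeros(4))
    opt = optimizer_dict["AdamW"](params=[param], lr=1e-3)
    print(type(opt).__name__)  # AdamW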
@@ -139,6 +140,8 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
         self.weight_init = weight_init
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
+        self.optimizer_name = None
+        self.optimizer_state_dict = None
 
         for size in enable_sizes or []:
             self.layers[size] = (
@@ -171,6 +174,10 @@ def save(self, filename):
         state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+        if self.optimizer_name is not None:
+            state_dict['optimizer_name'] = self.optimizer_name
+        if self.optimizer_state_dict:
+            state_dict['optimizer_state_dict'] = self.optimizer_state_dict
 
         torch.save(state_dict, filename)
 
@@ -190,7 +197,14 @@ def load(self, filename):
         self.add_layer_norm = state_dict.get('is_layer_norm', False)
         print(f"Layer norm is set to {self.add_layer_norm}")
         self.use_dropout = state_dict.get('use_dropout', False)
-        print(f"Dropout usage is set to {self.use_dropout}" )
+        print(f"Dropout usage is set to {self.use_dropout}")
+        self.optimizer_name = state_dict.get('optimizer_name', 'AdamW')
+        print(f"Optimizer name is {self.optimizer_name}")
+        self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None)
+        if self.optimizer_state_dict:
+            print("Loaded existing optimizer from checkpoint")
+        else:
+            print("No saved optimizer exists in checkpoint")
 
         for size, sd in state_dict.items():
             if type(size) == int:
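
The save()/load() changes above simply piggyback the optimizer metadata on the same .pt checkpoint dict as the hypernetwork itself, with load() falling back to 'AdamW' and None when the keys are absent. A rough sketch of that round trip under the same key names (the file name and the stand-in weight are illustrative):

    import torch

    # Hypothetical stand-ins for a hypernetwork's weights and its optimizer.
    weight = torch.nn.Parameter(torch.randn(8))
    optimizer = torch.optim.AdamW([weight], lr=1e-3)
    weight.sum().backward()
    optimizer.step()  # populate exp_avg / exp_avg_sq so there is real state to save

    # save(): the optimizer info rides along in the same checkpoint dict.
    state_dict = {"name": "demo-hypernetwork"}
    state_dict["optimizer_name"] = "AdamW"
    state_dict["optimizer_state_dict"] = optimizer.state_dict()
    torch.save(state_dict, "demo.pt")

    # load(): missing keys fall back to defaults, as in the diff.
    loaded = torch.load("demo.pt")
    optimizer_name = loaded.get("optimizer_name", "AdamW")
    optimizer_state_dict = loaded.get("optimizer_state_dict", None)
    print(optimizer_name, optimizer_state_dict is not None)  # AdamW True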
@@ -392,8 +406,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
-    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
-    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+    # Here we use optimizer from saved HN, or we can specify as UI option.
+    if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
+        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+    else:
+        print(f"Optimizer type {optimizer_name} is not defined!")
+        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+        optimizer_name = 'AdamW'
+    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
+        try:
+            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+        except RuntimeError as e:
+            print("Cannot resume from saved optimizer!")
+            print(e)
 
     steps_without_grad = 0
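In the training path above, the optimizer class is looked up by the name stored on the hypernetwork (falling back to AdamW when the name is unknown), and the saved state is then restored with load_state_dict, which can raise if the saved state does not fit the freshly built optimizer, e.g. because the parameter groups differ (the diff catches RuntimeError). A condensed sketch of that flow; build_optimizer and the getattr-based lookup are stand-ins, not the module's actual helpers:

    import torch

    def build_optimizer(weights, lr, optimizer_name, saved_state=None):
        # Look the optimizer class up by name; fall back to AdamW like the diff does.
        cls = getattr(torch.optim, optimizer_name, None)
        if cls is None:
            print(f"Optimizer type {optimizer_name} is not defined!")
            cls, optimizer_name = torch.optim.AdamW, "AdamW"
        optimizer = cls(params=weights, lr=lr)

        # Resuming only works if the saved state matches this optimizer's layout.
        if saved_state:
            try:
                optimizer.load_state_dict(saved_state)
            except (RuntimeError, ValueError) as e:
                print("Cannot resume from saved optimizer!")
                print(e)
        return optimizer, optimizer_name

    weights = [torch.nn.Parameter(torch.zeros(3))]
    opt, name = build_optimizer(weights, lr=1e-3, optimizer_name="SGD")
    print(name, type(opt).__name__)  # SGD SGD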

@@ -455,8 +480,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             # Before saving, change name to match current checkpoint.
             hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
             last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+            hypernetwork.optimizer_name = optimizer_name
+            if shared.opts.save_optimizer_state:
+                hypernetwork.optimizer_state_dict = optimizer.state_dict()
             save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
-
+            hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
         textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
             "loss": f"{previous_mean_loss:.7f}",
             "learn_rate": scheduler.learn_rate
@@ -514,14 +542,18 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
-
     report_statistics(loss_dict)
 
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+    hypernetwork.optimizer_name = optimizer_name
+    if shared.opts.save_optimizer_state:
+        hypernetwork.optimizer_state_dict = optimizer.state_dict()
     save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-
+    del optimizer
+    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
     return hypernetwork, filename
 
+
 def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
     old_hypernetwork_name = hypernetwork.name
     old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
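
Both save sites in this commit follow the same pattern: attach optimizer_name and, when the save_optimizer_state option is enabled, the optimizer's state_dict to the hypernetwork only for the duration of the save, then drop the reference, since for AdamW the two moment buffers are roughly the size of the trained weights twice over and keeping an extra copy alive would waste memory. A toy illustration of that attach/save/clear pattern; every name here is hypothetical:

    import torch

    class TinyNet(torch.nn.Module):
        # Hypothetical stand-in for a hypernetwork being trained.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.optimizer_name = None
            self.optimizer_state_dict = None

    net = TinyNet()
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)
    save_optimizer_state = True  # stand-in for the UI option

    # Attach the optimizer state only while writing the checkpoint ...
    net.optimizer_name = "AdamW"
    if save_optimizer_state:
        net.optimizer_state_dict = optimizer.state_dict()
    torch.save({"model": net.state_dict(),
                "optimizer_name": net.optimizer_name,
                "optimizer_state_dict": net.optimizer_state_dict}, "tiny.pt")

    # ... then drop the reference so the large moment buffers are not held twice.
    net.optimizer_state_dict = None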
