
Commit 0abb39f

resolve conflict - first revert
1 parent 1764ac3 commit 0abb39f

File tree: 1 file changed (+52 −71 lines)


modules/hypernetworks/hypernetwork.py

Lines changed: 52 additions & 71 deletions
@@ -21,7 +21,6 @@
 from collections import defaultdict, deque
 from statistics import stdev, mean
 
-optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
 
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
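
The deleted module-level line built a name-to-class table for every optimizer that torch.optim exposes, so a resumed training run could reconstruct the optimizer a hypernetwork was saved with. A minimal standalone sketch of what that one-liner produces (nothing here is webui-specific):

import inspect

import torch

# Map optimizer class names to the classes themselves, skipping the abstract base class.
optimizer_dict = {optim_name: cls_obj
                  for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass)
                  if optim_name != "Optimizer"}

print(sorted(optimizer_dict))  # e.g. ['ASGD', 'Adadelta', 'Adagrad', 'Adam', 'AdamW', ...]
optimizer = optimizer_dict["AdamW"](params=[torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
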
@@ -34,9 +33,12 @@ class HypernetworkModule(torch.nn.Module):
         "tanh": torch.nn.Tanh,
         "sigmoid": torch.nn.Sigmoid,
     }
-    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+    activation_dict.update(
+        {cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
+         inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+                 add_layer_norm=False, use_dropout=False):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
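
The update() call being re-wrapped here widens the hand-written activation table with every activation class defined in torch.nn.modules.activation, keyed by its lowercased class name. A small standalone sketch of the lookup this enables:

import inspect

import torch

activation_dict = {"tanh": torch.nn.Tanh, "sigmoid": torch.nn.Sigmoid}
activation_dict.update(
    {cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
     inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})

layer = activation_dict["leakyrelu"]()  # resolves to torch.nn.LeakyReLU
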
@@ -47,7 +49,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
         for i in range(len(layer_structure) - 1):
 
             # Add a fully-connected layer
-            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
 
             # Add an activation func
             if activation_func == "linear" or activation_func is None:
@@ -59,7 +61,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
 
             # Add layer normalization
             if add_layer_norm:
-                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
 
             # Add dropout expect last layer
             if use_dropout and i < len(layer_structure) - 3:
@@ -128,7 +130,8 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None,
+                 add_layer_norm=False, use_dropout=False):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -140,13 +143,13 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
         self.weight_init = weight_init
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
-        self.optimizer_name = None
-        self.optimizer_state_dict = None
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout),
             )
 
     def weights(self):
@@ -161,7 +164,6 @@ def weights(self):
 
     def save(self, filename):
         state_dict = {}
-        optimizer_saved_dict = {}
 
         for k, v in self.layers.items():
             state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -175,14 +177,8 @@ def save(self, filename):
         state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
-        if self.optimizer_name is not None:
-            optimizer_saved_dict['optimizer_name'] = self.optimizer_name
 
         torch.save(state_dict, filename)
-        if self.optimizer_state_dict:
-            optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
-            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
-            torch.save(optimizer_saved_dict, filename + '.optim')
 
     def load(self, filename):
         self.filename = filename
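
The deleted branches are one half of the feature this commit reverts: the optimizer name, a hash of the freshly written checkpoint, and the optimizer's state_dict went into a sidecar file next to the hypernetwork ('name.pt.optim'). A minimal sketch of that reverted pattern, with compute_hash standing in for the webui's sd_models.model_hash:

import torch

def save_with_optimizer(state_dict, optimizer, filename, compute_hash):
    # The main checkpoint keeps only model weights and metadata.
    torch.save(state_dict, filename)
    # The sidecar ties the optimizer state to the exact checkpoint it was saved with.
    torch.save({
        'optimizer_name': type(optimizer).__name__,
        'hash': compute_hash(filename),
        'optimizer_state_dict': optimizer.state_dict(),
    }, filename + '.optim')
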
@@ -202,23 +198,13 @@ def load(self, filename):
         self.use_dropout = state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}")
 
-        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
-        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
-        print(f"Optimizer name is {self.optimizer_name}")
-        if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
-            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
-        else:
-            self.optimizer_state_dict = None
-        if self.optimizer_state_dict:
-            print("Loaded existing optimizer from checkpoint")
-        else:
-            print("No saved optimizer exists in checkpoint")
-
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout),
                 )
 
         self.name = state_dict.get('name', self.name)
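
The load-side counterpart, removed in the same revert, only restored optimizer state when the stored hash matched the checkpoint actually being loaded, and otherwise dropped it. Roughly, under the same stand-in name:

import os

import torch

def load_optimizer_state(filename, compute_hash):
    optim_path = filename + '.optim'
    saved = torch.load(optim_path, map_location='cpu') if os.path.exists(optim_path) else {}
    name = saved.get('optimizer_name', 'AdamW')
    # Trust the saved state only if it was written against this exact checkpoint.
    state = saved.get('optimizer_state_dict') if saved.get('hash') == compute_hash(filename) else None
    return name, state
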
@@ -233,7 +219,7 @@ def list_hypernetworks(path):
         name = os.path.splitext(os.path.basename(filename))[0]
         # Prevent a hypothetical "None.pt" from being listed.
         if name != "None":
-            res[name + f"({sd_models.model_hash(filename)})"] = filename
+            res[name] = filename
     return res
 
 
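
With the model hash dropped from the key, hypernetworks are listed under their bare file names again; a hypothetical example of the resulting keys:

# Before: res["my_hypernet(abc1de2f)"] = "models/hypernetworks/my_hypernet.pt"
# After:  res["my_hypernet"]           = "models/hypernetworks/my_hypernet.pt"

This also makes the hypernetwork_name.rsplit('(', 1)[0] strip removed later in this diff unnecessary.
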

@@ -330,7 +316,7 @@ def statistics(data):
         std = 0
     else:
         std = stdev(data)
-    total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
+    total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
     recent_data = data[-32:]
     if len(recent_data) < 2:
         std = 0
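
The re-spaced expression reports the mean loss together with its standard error, std / sqrt(n). A quick runnable check with made-up numbers:

from statistics import stdev, mean

data = [0.2, 0.3, 0.4]
std = stdev(data)  # 0.1
print(f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})")
# -> loss:0.300±(0.058)
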
@@ -340,7 +326,7 @@ def statistics(data):
     return total_information, recent_information
 
 
-def report_statistics(loss_info:dict):
+def report_statistics(loss_info: dict):
     keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
     for key in keys:
         try:
@@ -352,14 +338,18 @@ def report_statistics(loss_info:dict):
             print(e)
 
 
-
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
+                       training_height, steps, create_image_every, save_hypernetwork_every, template_file,
+                       preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
+                       preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
 
     save_hypernetwork_every = save_hypernetwork_every or 0
     create_image_every = create_image_every or 0
-    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps,
+                                            save_hypernetwork_every, create_image_every, log_directory,
+                                            name="hypernetwork")
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
     shared.loaded_hypernetwork = Hypernetwork()
@@ -379,7 +369,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     else:
         hypernetwork_dir = None
 
-    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
     if create_image_every > 0:
         images_dir = os.path.join(log_directory, "images")
         os.makedirs(images_dir, exist_ok=True)
@@ -395,39 +384,34 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         return hypernetwork, filename
 
     scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-
+
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
+                                                                height=training_height,
+                                                                repeats=shared.opts.training_image_repeats_per_epoch,
+                                                                placeholder_token=hypernetwork_name,
+                                                                model=shared.sd_model, device=devices.device,
+                                                                template_file=template_file, include_cond=True,
+                                                                batch_size=batch_size)
 
     if unload:
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
 
     size = len(ds.indexes)
-    loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+    loss_dict = defaultdict(lambda: deque(maxlen=1024))
     losses = torch.zeros((size,))
     previous_mean_losses = [0]
     previous_mean_loss = 0
     print("Mean loss of {} elements".format(size))
-
+
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
-    # Here we use optimizer from saved HN, or we can specify as UI option.
-    if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
-        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
-    else:
-        print(f"Optimizer type {optimizer_name} is not defined!")
-        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
-        optimizer_name = 'AdamW'
-    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
-        try:
-            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
-        except RuntimeError as e:
-            print("Cannot resume from saved optimizer!")
-            print(e)
+    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
+    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
 
     steps_without_grad = 0
 
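
The removed block was the core of the reverted feature: it looked up the optimizer class named in the loaded hypernetwork through the (also removed) optimizer_dict, fell back to AdamW for unknown names, and tried to restore the saved state_dict, which raises a RuntimeError when the saved state no longer matches the parameter groups. A standalone sketch of that select-and-resume pattern:

import inspect

import torch

optimizer_dict = {name: cls for name, cls in inspect.getmembers(torch.optim, inspect.isclass)
                  if name != "Optimizer"}

def build_optimizer(weights, lr, optimizer_name=None, saved_state=None):
    # Unknown or missing names fall back to AdamW, mirroring the deleted branch.
    cls = optimizer_dict.get(optimizer_name, torch.optim.AdamW)
    optimizer = cls(params=weights, lr=lr)
    if saved_state is not None:
        try:
            optimizer.load_state_dict(saved_state)  # restores step counts and momentum buffers
        except RuntimeError as e:
            print("Cannot resume from saved optimizer!")
            print(e)
    return optimizer
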

@@ -441,7 +425,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         if len(loss_dict) > 0:
             previous_mean_losses = [i[-1] for i in loss_dict.values()]
             previous_mean_loss = mean(previous_mean_losses)
-
+
         scheduler.apply(optimizer, hypernetwork.step)
         if scheduler.finished:
             break
@@ -460,7 +444,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             losses[hypernetwork.step % losses.shape[0]] = loss.item()
             for entry in entries:
                 loss_dict[entry.filename].append(loss.item())
-
+
             optimizer.zero_grad()
             weights[0].grad = None
             loss.backward()
@@ -475,9 +459,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 
         steps_done = hypernetwork.step + 1
 
-        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): 
+        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
             raise RuntimeError("Loss diverged.")
-
+
         if len(previous_mean_losses) > 1:
             std = stdev(previous_mean_losses)
         else:
@@ -489,11 +473,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             # Before saving, change name to match current checkpoint.
             hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
             last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
-            hypernetwork.optimizer_name = optimizer_name
-            if shared.opts.save_optimizer_state:
-                hypernetwork.optimizer_state_dict = optimizer.state_dict()
             save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
-            hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
+
         textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
             "loss": f"{previous_mean_loss:.7f}",
             "learn_rate": scheduler.learn_rate
@@ -529,15 +510,18 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             preview_text = p.prompt
 
             processed = processing.process_images(p)
-            image = processed.images[0] if len(processed.images)>0 else None
+            image = processed.images[0] if len(processed.images) > 0 else None
 
             if unload:
                 shared.sd_model.cond_stage_model.to(devices.cpu)
                 shared.sd_model.first_stage_model.to(devices.cpu)
 
             if image is not None:
                 shared.state.current_image = image
-                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
+                                                                     shared.opts.samples_format, processed.infotexts[0],
+                                                                     p=p, forced_filename=forced_filename,
+                                                                     save_to_dirs=False)
                 last_saved_image += f", prompt: {preview_text}"
 
         shared.state.job_no = hypernetwork.step
@@ -551,15 +535,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
+
     report_statistics(loss_dict)
 
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
-    hypernetwork.optimizer_name = optimizer_name
-    if shared.opts.save_optimizer_state:
-        hypernetwork.optimizer_state_dict = optimizer.state_dict()
     save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-    del optimizer
-    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
+
     return hypernetwork, filename
 
 

@@ -576,4 +557,4 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
         hypernetwork.sd_checkpoint = old_sd_checkpoint
         hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
         hypernetwork.name = old_hypernetwork_name
-        raise
+        raise
