Skip to content

Commit 0d07cbf

Browse files
authored
I blame code autocomplete
1 parent 0abb39f commit 0d07cbf

File tree

1 file changed

+27
-49
lines changed

1 file changed

+27
-49
lines changed

modules/hypernetworks/hypernetwork.py

Lines changed: 27 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,9 @@ class HypernetworkModule(torch.nn.Module):
3333
"tanh": torch.nn.Tanh,
3434
"sigmoid": torch.nn.Sigmoid,
3535
}
36-
activation_dict.update(
37-
{cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
38-
inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
36+
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
3937

40-
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
41-
add_layer_norm=False, use_dropout=False):
38+
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
4239
super().__init__()
4340

4441
assert layer_structure is not None, "layer_structure must not be None"
@@ -49,7 +46,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
4946
for i in range(len(layer_structure) - 1):
5047

5148
# Add a fully-connected layer
52-
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
49+
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
5350

5451
# Add an activation func
5552
if activation_func == "linear" or activation_func is None:
@@ -61,7 +58,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
6158

6259
# Add layer normalization
6360
if add_layer_norm:
64-
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
61+
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
6562

6663
# Add dropout expect last layer
6764
if use_dropout and i < len(layer_structure) - 3:
@@ -130,8 +127,7 @@ class Hypernetwork:
130127
filename = None
131128
name = None
132129

133-
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None,
134-
add_layer_norm=False, use_dropout=False):
130+
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
135131
self.filename = None
136132
self.name = name
137133
self.layers = {}
@@ -146,10 +142,8 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
146142

147143
for size in enable_sizes or []:
148144
self.layers[size] = (
149-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
150-
self.add_layer_norm, self.use_dropout),
151-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
152-
self.add_layer_norm, self.use_dropout),
145+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
146+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
153147
)
154148

155149
def weights(self):
@@ -196,15 +190,13 @@ def load(self, filename):
196190
self.add_layer_norm = state_dict.get('is_layer_norm', False)
197191
print(f"Layer norm is set to {self.add_layer_norm}")
198192
self.use_dropout = state_dict.get('use_dropout', False)
199-
print(f"Dropout usage is set to {self.use_dropout}")
193+
print(f"Dropout usage is set to {self.use_dropout}" )
200194

201195
for size, sd in state_dict.items():
202196
if type(size) == int:
203197
self.layers[size] = (
204-
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
205-
self.add_layer_norm, self.use_dropout),
206-
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
207-
self.add_layer_norm, self.use_dropout),
198+
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
199+
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
208200
)
209201

210202
self.name = state_dict.get('name', self.name)
@@ -316,7 +308,7 @@ def statistics(data):
316308
std = 0
317309
else:
318310
std = stdev(data)
319-
total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
311+
total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
320312
recent_data = data[-32:]
321313
if len(recent_data) < 2:
322314
std = 0
@@ -326,7 +318,7 @@ def statistics(data):
326318
return total_information, recent_information
327319

328320

329-
def report_statistics(loss_info: dict):
321+
def report_statistics(loss_info:dict):
330322
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
331323
for key in keys:
332324
try:
@@ -338,18 +330,14 @@ def report_statistics(loss_info: dict):
338330
print(e)
339331

340332

341-
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
342-
training_height, steps, create_image_every, save_hypernetwork_every, template_file,
343-
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
344-
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
333+
334+
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
345335
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
346336
from modules import images
347337

348338
save_hypernetwork_every = save_hypernetwork_every or 0
349339
create_image_every = create_image_every or 0
350-
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps,
351-
save_hypernetwork_every, create_image_every, log_directory,
352-
name="hypernetwork")
340+
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
353341

354342
path = shared.hypernetworks.get(hypernetwork_name, None)
355343
shared.loaded_hypernetwork = Hypernetwork()
@@ -384,29 +372,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
384372
return hypernetwork, filename
385373

386374
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
387-
375+
388376
# dataset loading may take a while, so input validations and early returns should be done before this
389377
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
390378
with torch.autocast("cuda"):
391-
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
392-
height=training_height,
393-
repeats=shared.opts.training_image_repeats_per_epoch,
394-
placeholder_token=hypernetwork_name,
395-
model=shared.sd_model, device=devices.device,
396-
template_file=template_file, include_cond=True,
397-
batch_size=batch_size)
379+
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
398380

399381
if unload:
400382
shared.sd_model.cond_stage_model.to(devices.cpu)
401383
shared.sd_model.first_stage_model.to(devices.cpu)
402384

403385
size = len(ds.indexes)
404-
loss_dict = defaultdict(lambda: deque(maxlen=1024))
386+
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
405387
losses = torch.zeros((size,))
406388
previous_mean_losses = [0]
407389
previous_mean_loss = 0
408390
print("Mean loss of {} elements".format(size))
409-
391+
410392
weights = hypernetwork.weights()
411393
for weight in weights:
412394
weight.requires_grad = True
@@ -425,7 +407,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
425407
if len(loss_dict) > 0:
426408
previous_mean_losses = [i[-1] for i in loss_dict.values()]
427409
previous_mean_loss = mean(previous_mean_losses)
428-
410+
429411
scheduler.apply(optimizer, hypernetwork.step)
430412
if scheduler.finished:
431413
break
@@ -444,7 +426,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
444426
losses[hypernetwork.step % losses.shape[0]] = loss.item()
445427
for entry in entries:
446428
loss_dict[entry.filename].append(loss.item())
447-
429+
448430
optimizer.zero_grad()
449431
weights[0].grad = None
450432
loss.backward()
@@ -459,9 +441,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
459441

460442
steps_done = hypernetwork.step + 1
461443

462-
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
444+
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
463445
raise RuntimeError("Loss diverged.")
464-
446+
465447
if len(previous_mean_losses) > 1:
466448
std = stdev(previous_mean_losses)
467449
else:
@@ -510,18 +492,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
510492
preview_text = p.prompt
511493

512494
processed = processing.process_images(p)
513-
image = processed.images[0] if len(processed.images) > 0 else None
495+
image = processed.images[0] if len(processed.images)>0 else None
514496

515497
if unload:
516498
shared.sd_model.cond_stage_model.to(devices.cpu)
517499
shared.sd_model.first_stage_model.to(devices.cpu)
518500

519501
if image is not None:
520502
shared.state.current_image = image
521-
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
522-
shared.opts.samples_format, processed.infotexts[0],
523-
p=p, forced_filename=forced_filename,
524-
save_to_dirs=False)
503+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
525504
last_saved_image += f", prompt: {preview_text}"
526505

527506
shared.state.job_no = hypernetwork.step
@@ -535,15 +514,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
535514
Last saved image: {html.escape(last_saved_image)}<br/>
536515
</p>
537516
"""
538-
517+
539518
report_statistics(loss_dict)
540519

541520
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
542521
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
543522

544523
return hypernetwork, filename
545524

546-
547525
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
548526
old_hypernetwork_name = hypernetwork.name
549527
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
@@ -557,4 +535,4 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
557535
hypernetwork.sd_checkpoint = old_sd_checkpoint
558536
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
559537
hypernetwork.name = old_hypernetwork_name
560-
raise
538+
raise

0 commit comments

Comments (0)