 from collections import defaultdict, deque
 from statistics import stdev, mean
 
-optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
 
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
@@ -34,9 +33,12 @@ class HypernetworkModule(torch.nn.Module):
         "tanh": torch.nn.Tanh,
         "sigmoid": torch.nn.Sigmoid,
     }
-    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+    activation_dict.update(
+        {cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
+         inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+                 add_layer_norm=False, use_dropout=False):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -47,7 +49,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
         for i in range(len(layer_structure) - 1):
 
             # Add a fully-connected layer
-            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
 
             # Add an activation func
             if activation_func == "linear" or activation_func is None:
@@ -59,7 +61,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
 
             # Add layer normalization
             if add_layer_norm:
-                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
 
             # Add dropout expect last layer
             if use_dropout and i < len(layer_structure) - 3:
@@ -128,7 +130,8 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None,
+                 add_layer_norm=False, use_dropout=False):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -140,13 +143,13 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
         self.weight_init = weight_init
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
-        self.optimizer_name = None
-        self.optimizer_state_dict = None
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout),
             )
 
     def weights(self):
@@ -161,7 +164,6 @@ def weights(self):
 
     def save(self, filename):
         state_dict = {}
-        optimizer_saved_dict = {}
 
         for k, v in self.layers.items():
             state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -175,14 +177,8 @@ def save(self, filename):
         state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
-        if self.optimizer_name is not None:
-            optimizer_saved_dict['optimizer_name'] = self.optimizer_name
 
         torch.save(state_dict, filename)
-        if self.optimizer_state_dict:
-            optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
-            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
-            torch.save(optimizer_saved_dict, filename + '.optim')
 
     def load(self, filename):
         self.filename = filename
@@ -202,23 +198,13 @@ def load(self, filename):
         self.use_dropout = state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}")
 
-        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
-        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
-        print(f"Optimizer name is {self.optimizer_name}")
-        if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
-            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
-        else:
-            self.optimizer_state_dict = None
-        if self.optimizer_state_dict:
-            print("Loaded existing optimizer from checkpoint")
-        else:
-            print("No saved optimizer exists in checkpoint")
-
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout),
                 )
 
         self.name = state_dict.get('name', self.name)
@@ -233,7 +219,7 @@ def list_hypernetworks(path):
         name = os.path.splitext(os.path.basename(filename))[0]
         # Prevent a hypothetical "None.pt" from being listed.
         if name != "None":
-            res[name + f"({sd_models.model_hash(filename)})"] = filename
+            res[name] = filename
     return res
 
 
@@ -330,7 +316,7 @@ def statistics(data):
         std = 0
     else:
         std = stdev(data)
-    total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
+    total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
     recent_data = data[-32:]
     if len(recent_data) < 2:
         std = 0
@@ -340,7 +326,7 @@ def statistics(data):
     return total_information, recent_information
 
 
-def report_statistics(loss_info:dict):
+def report_statistics(loss_info: dict):
     keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
     for key in keys:
         try:
@@ -352,14 +338,18 @@ def report_statistics(loss_info:dict):
             print(e)
 
 
-
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
+                       training_height, steps, create_image_every, save_hypernetwork_every, template_file,
+                       preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
+                       preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
 
     save_hypernetwork_every = save_hypernetwork_every or 0
     create_image_every = create_image_every or 0
-    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps,
+                                            save_hypernetwork_every, create_image_every, log_directory,
+                                            name="hypernetwork")
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
     shared.loaded_hypernetwork = Hypernetwork()
@@ -379,7 +369,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     else:
         hypernetwork_dir = None
 
-    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
     if create_image_every > 0:
         images_dir = os.path.join(log_directory, "images")
         os.makedirs(images_dir, exist_ok=True)
@@ -395,39 +384,34 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         return hypernetwork, filename
 
     scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-
+
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
+                                                                height=training_height,
+                                                                repeats=shared.opts.training_image_repeats_per_epoch,
+                                                                placeholder_token=hypernetwork_name,
+                                                                model=shared.sd_model, device=devices.device,
+                                                                template_file=template_file, include_cond=True,
+                                                                batch_size=batch_size)
 
     if unload:
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
 
     size = len(ds.indexes)
-    loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+    loss_dict = defaultdict(lambda: deque(maxlen=1024))
     losses = torch.zeros((size,))
     previous_mean_losses = [0]
     previous_mean_loss = 0
     print("Mean loss of {} elements".format(size))
-
+
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
-    # Here we use optimizer from saved HN, or we can specify as UI option.
-    if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
-        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
-    else:
-        print(f"Optimizer type {optimizer_name} is not defined!")
-        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
-        optimizer_name = 'AdamW'
-    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
-        try:
-            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
-        except RuntimeError as e:
-            print("Cannot resume from saved optimizer!")
-            print(e)
+    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
+    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
 
     steps_without_grad = 0
 
@@ -441,7 +425,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         if len(loss_dict) > 0:
             previous_mean_losses = [i[-1] for i in loss_dict.values()]
             previous_mean_loss = mean(previous_mean_losses)
-
+
         scheduler.apply(optimizer, hypernetwork.step)
         if scheduler.finished:
             break
@@ -460,7 +444,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             losses[hypernetwork.step % losses.shape[0]] = loss.item()
             for entry in entries:
                 loss_dict[entry.filename].append(loss.item())
-
+
             optimizer.zero_grad()
             weights[0].grad = None
             loss.backward()
@@ -475,9 +459,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 
         steps_done = hypernetwork.step + 1
 
-        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
             raise RuntimeError("Loss diverged.")
-
+
         if len(previous_mean_losses) > 1:
             std = stdev(previous_mean_losses)
         else:
@@ -489,11 +473,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             # Before saving, change name to match current checkpoint.
             hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
             last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
-            hypernetwork.optimizer_name = optimizer_name
-            if shared.opts.save_optimizer_state:
-                hypernetwork.optimizer_state_dict = optimizer.state_dict()
             save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
-            hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
+
         textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
             "loss": f"{previous_mean_loss:.7f}",
             "learn_rate": scheduler.learn_rate
@@ -529,15 +510,18 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             preview_text = p.prompt
 
             processed = processing.process_images(p)
-            image = processed.images[0] if len(processed.images)>0 else None
+            image = processed.images[0] if len(processed.images) > 0 else None
 
             if unload:
                 shared.sd_model.cond_stage_model.to(devices.cpu)
                 shared.sd_model.first_stage_model.to(devices.cpu)
 
             if image is not None:
                 shared.state.current_image = image
-                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
+                                                                     shared.opts.samples_format, processed.infotexts[0],
+                                                                     p=p, forced_filename=forced_filename,
+                                                                     save_to_dirs=False)
                 last_saved_image += f", prompt: {preview_text}"
 
         shared.state.job_no = hypernetwork.step
@@ -551,15 +535,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
+
     report_statistics(loss_dict)
 
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
-    hypernetwork.optimizer_name = optimizer_name
-    if shared.opts.save_optimizer_state:
-        hypernetwork.optimizer_state_dict = optimizer.state_dict()
     save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-    del optimizer
-    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
+
     return hypernetwork, filename
 
 
@@ -576,4 +557,4 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
         hypernetwork.sd_checkpoint = old_sd_checkpoint
         hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
         hypernetwork.name = old_hypernetwork_name
-        raise
+        raise