@@ -33,12 +33,9 @@ class HypernetworkModule(torch.nn.Module):
33
33
"tanh" : torch .nn .Tanh ,
34
34
"sigmoid" : torch .nn .Sigmoid ,
35
35
}
36
- activation_dict .update (
37
- {cls_name .lower (): cls_obj for cls_name , cls_obj in inspect .getmembers (torch .nn .modules .activation ) if
38
- inspect .isclass (cls_obj ) and cls_obj .__module__ == 'torch.nn.modules.activation' })
36
+ activation_dict .update ({cls_name .lower (): cls_obj for cls_name , cls_obj in inspect .getmembers (torch .nn .modules .activation ) if inspect .isclass (cls_obj ) and cls_obj .__module__ == 'torch.nn.modules.activation' })
39
37
40
- def __init__ (self , dim , state_dict = None , layer_structure = None , activation_func = None , weight_init = 'Normal' ,
41
- add_layer_norm = False , use_dropout = False ):
38
+ def __init__ (self , dim , state_dict = None , layer_structure = None , activation_func = None , weight_init = 'Normal' , add_layer_norm = False , use_dropout = False ):
42
39
super ().__init__ ()
43
40
44
41
assert layer_structure is not None , "layer_structure must not be None"
@@ -49,7 +46,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
49
46
for i in range (len (layer_structure ) - 1 ):
50
47
51
48
# Add a fully-connected layer
52
- linears .append (torch .nn .Linear (int (dim * layer_structure [i ]), int (dim * layer_structure [i + 1 ])))
49
+ linears .append (torch .nn .Linear (int (dim * layer_structure [i ]), int (dim * layer_structure [i + 1 ])))
53
50
54
51
# Add an activation func
55
52
if activation_func == "linear" or activation_func is None :
@@ -61,7 +58,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
61
58
62
59
# Add layer normalization
63
60
if add_layer_norm :
64
- linears .append (torch .nn .LayerNorm (int (dim * layer_structure [i + 1 ])))
61
+ linears .append (torch .nn .LayerNorm (int (dim * layer_structure [i + 1 ])))
65
62
66
63
# Add dropout except last layer
67
64
if use_dropout and i < len (layer_structure ) - 3 :
@@ -130,8 +127,7 @@ class Hypernetwork:
130
127
filename = None
131
128
name = None
132
129
133
- def __init__ (self , name = None , enable_sizes = None , layer_structure = None , activation_func = None , weight_init = None ,
134
- add_layer_norm = False , use_dropout = False ):
130
+ def __init__ (self , name = None , enable_sizes = None , layer_structure = None , activation_func = None , weight_init = None , add_layer_norm = False , use_dropout = False ):
135
131
self .filename = None
136
132
self .name = name
137
133
self .layers = {}
@@ -146,10 +142,8 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
146
142
147
143
for size in enable_sizes or []:
148
144
self .layers [size ] = (
149
- HypernetworkModule (size , None , self .layer_structure , self .activation_func , self .weight_init ,
150
- self .add_layer_norm , self .use_dropout ),
151
- HypernetworkModule (size , None , self .layer_structure , self .activation_func , self .weight_init ,
152
- self .add_layer_norm , self .use_dropout ),
145
+ HypernetworkModule (size , None , self .layer_structure , self .activation_func , self .weight_init , self .add_layer_norm , self .use_dropout ),
146
+ HypernetworkModule (size , None , self .layer_structure , self .activation_func , self .weight_init , self .add_layer_norm , self .use_dropout ),
153
147
)
154
148
155
149
def weights (self ):
@@ -196,15 +190,13 @@ def load(self, filename):
196
190
self .add_layer_norm = state_dict .get ('is_layer_norm' , False )
197
191
print (f"Layer norm is set to { self .add_layer_norm } " )
198
192
self .use_dropout = state_dict .get ('use_dropout' , False )
199
- print (f"Dropout usage is set to { self .use_dropout } " )
193
+ print (f"Dropout usage is set to { self .use_dropout } " )
200
194
201
195
for size , sd in state_dict .items ():
202
196
if type (size ) == int :
203
197
self .layers [size ] = (
204
- HypernetworkModule (size , sd [0 ], self .layer_structure , self .activation_func , self .weight_init ,
205
- self .add_layer_norm , self .use_dropout ),
206
- HypernetworkModule (size , sd [1 ], self .layer_structure , self .activation_func , self .weight_init ,
207
- self .add_layer_norm , self .use_dropout ),
198
+ HypernetworkModule (size , sd [0 ], self .layer_structure , self .activation_func , self .weight_init , self .add_layer_norm , self .use_dropout ),
199
+ HypernetworkModule (size , sd [1 ], self .layer_structure , self .activation_func , self .weight_init , self .add_layer_norm , self .use_dropout ),
208
200
)
209
201
210
202
self .name = state_dict .get ('name' , self .name )
@@ -316,7 +308,7 @@ def statistics(data):
316
308
std = 0
317
309
else :
318
310
std = stdev (data )
319
- total_information = f"loss:{ mean (data ):.3f} " + u"\u00B1 " + f"({ std / (len (data ) ** 0.5 ):.3f} )"
311
+ total_information = f"loss:{ mean (data ):.3f} " + u"\u00B1 " + f"({ std / (len (data ) ** 0.5 ):.3f} )"
320
312
recent_data = data [- 32 :]
321
313
if len (recent_data ) < 2 :
322
314
std = 0
@@ -326,7 +318,7 @@ def statistics(data):
326
318
return total_information , recent_information
327
319
328
320
329
- def report_statistics (loss_info : dict ):
321
+ def report_statistics (loss_info :dict ):
330
322
keys = sorted (loss_info .keys (), key = lambda x : sum (loss_info [x ]) / len (loss_info [x ]))
331
323
for key in keys :
332
324
try :
@@ -338,18 +330,14 @@ def report_statistics(loss_info: dict):
338
330
print (e )
339
331
340
332
341
- def train_hypernetwork (hypernetwork_name , learn_rate , batch_size , data_root , log_directory , training_width ,
342
- training_height , steps , create_image_every , save_hypernetwork_every , template_file ,
343
- preview_from_txt2img , preview_prompt , preview_negative_prompt , preview_steps ,
344
- preview_sampler_index , preview_cfg_scale , preview_seed , preview_width , preview_height ):
333
+
334
+ def train_hypernetwork (hypernetwork_name , learn_rate , batch_size , data_root , log_directory , training_width , training_height , steps , create_image_every , save_hypernetwork_every , template_file , preview_from_txt2img , preview_prompt , preview_negative_prompt , preview_steps , preview_sampler_index , preview_cfg_scale , preview_seed , preview_width , preview_height ):
345
335
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
346
336
from modules import images
347
337
348
338
save_hypernetwork_every = save_hypernetwork_every or 0
349
339
create_image_every = create_image_every or 0
350
- textual_inversion .validate_train_inputs (hypernetwork_name , learn_rate , batch_size , data_root , template_file , steps ,
351
- save_hypernetwork_every , create_image_every , log_directory ,
352
- name = "hypernetwork" )
340
+ textual_inversion .validate_train_inputs (hypernetwork_name , learn_rate , batch_size , data_root , template_file , steps , save_hypernetwork_every , create_image_every , log_directory , name = "hypernetwork" )
353
341
354
342
path = shared .hypernetworks .get (hypernetwork_name , None )
355
343
shared .loaded_hypernetwork = Hypernetwork ()
@@ -384,29 +372,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
384
372
return hypernetwork , filename
385
373
386
374
scheduler = LearnRateScheduler (learn_rate , steps , ititial_step )
387
-
375
+
388
376
# dataset loading may take a while, so input validations and early returns should be done before this
389
377
shared .state .textinfo = f"Preparing dataset from { html .escape (data_root )} ..."
390
378
with torch .autocast ("cuda" ):
391
- ds = modules .textual_inversion .dataset .PersonalizedBase (data_root = data_root , width = training_width ,
392
- height = training_height ,
393
- repeats = shared .opts .training_image_repeats_per_epoch ,
394
- placeholder_token = hypernetwork_name ,
395
- model = shared .sd_model , device = devices .device ,
396
- template_file = template_file , include_cond = True ,
397
- batch_size = batch_size )
379
+ ds = modules .textual_inversion .dataset .PersonalizedBase (data_root = data_root , width = training_width , height = training_height , repeats = shared .opts .training_image_repeats_per_epoch , placeholder_token = hypernetwork_name , model = shared .sd_model , device = devices .device , template_file = template_file , include_cond = True , batch_size = batch_size )
398
380
399
381
if unload :
400
382
shared .sd_model .cond_stage_model .to (devices .cpu )
401
383
shared .sd_model .first_stage_model .to (devices .cpu )
402
384
403
385
size = len (ds .indexes )
404
- loss_dict = defaultdict (lambda : deque (maxlen = 1024 ))
386
+ loss_dict = defaultdict (lambda : deque (maxlen = 1024 ))
405
387
losses = torch .zeros ((size ,))
406
388
previous_mean_losses = [0 ]
407
389
previous_mean_loss = 0
408
390
print ("Mean loss of {} elements" .format (size ))
409
-
391
+
410
392
weights = hypernetwork .weights ()
411
393
for weight in weights :
412
394
weight .requires_grad = True
@@ -425,7 +407,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
425
407
if len (loss_dict ) > 0 :
426
408
previous_mean_losses = [i [- 1 ] for i in loss_dict .values ()]
427
409
previous_mean_loss = mean (previous_mean_losses )
428
-
410
+
429
411
scheduler .apply (optimizer , hypernetwork .step )
430
412
if scheduler .finished :
431
413
break
@@ -444,7 +426,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
444
426
losses [hypernetwork .step % losses .shape [0 ]] = loss .item ()
445
427
for entry in entries :
446
428
loss_dict [entry .filename ].append (loss .item ())
447
-
429
+
448
430
optimizer .zero_grad ()
449
431
weights [0 ].grad = None
450
432
loss .backward ()
@@ -459,9 +441,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
459
441
460
442
steps_done = hypernetwork .step + 1
461
443
462
- if torch .isnan (losses [hypernetwork .step % losses .shape [0 ]]):
444
+ if torch .isnan (losses [hypernetwork .step % losses .shape [0 ]]):
463
445
raise RuntimeError ("Loss diverged." )
464
-
446
+
465
447
if len (previous_mean_losses ) > 1 :
466
448
std = stdev (previous_mean_losses )
467
449
else :
@@ -510,18 +492,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
510
492
preview_text = p .prompt
511
493
512
494
processed = processing .process_images (p )
513
- image = processed .images [0 ] if len (processed .images ) > 0 else None
495
+ image = processed .images [0 ] if len (processed .images )> 0 else None
514
496
515
497
if unload :
516
498
shared .sd_model .cond_stage_model .to (devices .cpu )
517
499
shared .sd_model .first_stage_model .to (devices .cpu )
518
500
519
501
if image is not None :
520
502
shared .state .current_image = image
521
- last_saved_image , last_text_info = images .save_image (image , images_dir , "" , p .seed , p .prompt ,
522
- shared .opts .samples_format , processed .infotexts [0 ],
523
- p = p , forced_filename = forced_filename ,
524
- save_to_dirs = False )
503
+ last_saved_image , last_text_info = images .save_image (image , images_dir , "" , p .seed , p .prompt , shared .opts .samples_format , processed .infotexts [0 ], p = p , forced_filename = forced_filename , save_to_dirs = False )
525
504
last_saved_image += f", prompt: { preview_text } "
526
505
527
506
shared .state .job_no = hypernetwork .step
@@ -535,15 +514,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
535
514
Last saved image: { html .escape (last_saved_image )} <br/>
536
515
</p>
537
516
"""
538
-
517
+
539
518
report_statistics (loss_dict )
540
519
541
520
filename = os .path .join (shared .cmd_opts .hypernetwork_dir , f'{ hypernetwork_name } .pt' )
542
521
save_hypernetwork (hypernetwork , checkpoint , hypernetwork_name , filename )
543
522
544
523
return hypernetwork , filename
545
524
546
-
547
525
def save_hypernetwork (hypernetwork , checkpoint , hypernetwork_name , filename ):
548
526
old_hypernetwork_name = hypernetwork .name
549
527
old_sd_checkpoint = hypernetwork .sd_checkpoint if hasattr (hypernetwork , "sd_checkpoint" ) else None
@@ -557,4 +535,4 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
557
535
hypernetwork .sd_checkpoint = old_sd_checkpoint
558
536
hypernetwork .sd_checkpoint_name = old_sd_checkpoint_name
559
537
hypernetwork .name = old_hypernetwork_name
560
- raise
538
+ raise
0 commit comments