
Commit 4fad71c

Training optimizations (#217)
* Optimizations to the training model: based on the changes made in textual_inversion, I carried over the relevant changes that improve model training. These changes reduce the amount of memory used, significantly improve the speed at which training runs, and improve the quality of the results. They also fix the problem where the model trainer wouldn't automatically stop when it hit the set number of steps.
* Update main.py: cleaned up whitespace.
1 parent d126db2 commit 4fad71c

File tree: 2 files changed (+26, −7 lines)


configs/stable-diffusion/v1-finetune.yaml

Lines changed: 5 additions & 2 deletions
```diff
@@ -52,7 +52,7 @@ model:
         ddconfig:
           double_z: true
           z_channels: 4
-          resolution: 256
+          resolution: 512
           in_channels: 3
           out_ch: 3
           ch: 128
@@ -73,7 +73,7 @@ model:
 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 2
+    batch_size: 1
     num_workers: 16
     wrap: false
     train:
@@ -92,6 +92,9 @@ data:
         repeats: 10

 lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 500
   callbacks:
     image_logger:
       target: main.ImageLogger
```
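The new `lightning.modelcheckpoint` block is passed through to PyTorch Lightning's `ModelCheckpoint` callback when `main.py` builds it, so a checkpoint is written every 500 training steps. A minimal sketch of the callback this override corresponds to (the directory path is illustrative, not part of this commit):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Sketch only: save a checkpoint every 500 optimizer steps,
# matching every_n_train_steps: 500 from the YAML override above.
checkpoint_callback = ModelCheckpoint(
    dirpath='logs/example-run/checkpoints',  # illustrative path
    every_n_train_steps=500,
)
```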

main.py

Lines changed: 21 additions & 5 deletions
```diff
@@ -171,8 +171,8 @@ def str2bool(v):
         help='Initialize embedding manager from a checkpoint',
     )
     parser.add_argument(
-        '--placeholder_tokens', type=str, nargs='+', default=['*']
-    )
+        '--placeholder_tokens', type=str, nargs='+', default=['*'],
+        help='Placeholder token which will be used to denote the concept in future prompts')

     parser.add_argument(
         '--init_word',
```
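The change only adds help text; with `nargs='+'` the option still accepts one or more tokens and keeps `['*']` as its default. A small stand-alone sketch of that behaviour (the parser here is a stripped-down stand-in for the one in main.py):

```python
import argparse

# Stand-in parser mirroring only the argument touched by this hunk.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--placeholder_tokens', type=str, nargs='+', default=['*'],
    help='Placeholder token which will be used to denote the concept in future prompts')

print(parser.parse_args([]).placeholder_tokens)                              # ['*']
print(parser.parse_args(['--placeholder_tokens', '*']).placeholder_tokens)  # ['*']
```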
```diff
@@ -473,7 +473,7 @@ def log_img(self, pl_module, batch, batch_idx, split='train'):
             self.check_frequency(check_idx)
             and hasattr(  # batch_idx % self.batch_freq == 0
                 pl_module, 'log_images'
-            )
+            )
             and callable(pl_module.log_images)
             and self.max_images > 0
         ):
```
```diff
@@ -569,6 +569,21 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
         except AttributeError:
             pass

+class ModeSwapCallback(Callback):
+
+    def __init__(self, swap_step=2000):
+        super().__init__()
+        self.is_frozen = False
+        self.swap_step = swap_step
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        if trainer.global_step < self.swap_step and not self.is_frozen:
+            self.is_frozen = True
+            trainer.optimizers = [pl_module.configure_opt_embedding()]
+
+        if trainer.global_step > self.swap_step and self.is_frozen:
+            self.is_frozen = False
+            trainer.optimizers = [pl_module.configure_opt_model()]

 if __name__ == '__main__':
     # custom parser to specify config files, train, test and debug mode,
```
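`ModeSwapCallback` swaps the trainer's optimizer at epoch boundaries: before `swap_step` global steps it installs an embedding-only optimizer, afterwards a full-model optimizer. A hedged usage sketch, assuming the module exposes the `configure_opt_embedding()` and `configure_opt_model()` methods the callback calls (the callback is defined in this commit but not wired into the trainer in this diff):

```python
# Hypothetical wiring, not part of this commit: the first 2000 steps train
# only the embedding, after which the whole model is optimized.
from pytorch_lightning import Trainer

trainer = Trainer(
    max_steps=6100,                               # illustrative step budget
    callbacks=[ModeSwapCallback(swap_step=2000)],
)
# trainer.fit(model, data)  # `model` must provide configure_opt_embedding()
#                           # and configure_opt_model(), each returning an optimizer.
```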
```diff
@@ -663,6 +678,7 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
     if opt.datadir_in_name:
         now = os.path.basename(os.path.normpath(opt.data_root)) + now

+
     nowname = now + name + opt.postfix
     logdir = os.path.join(opt.logdir, nowname)

```
```diff
@@ -756,7 +772,7 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
     if hasattr(model, 'monitor'):
         print(f'Monitoring {model.monitor} as checkpoint metric.')
         default_modelckpt_cfg['params']['monitor'] = model.monitor
-        default_modelckpt_cfg['params']['save_top_k'] = 3
+        default_modelckpt_cfg['params']['save_top_k'] = 1

     if 'modelcheckpoint' in lightning_config:
         modelckpt_cfg = lightning_config.modelcheckpoint
```
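With `save_top_k` lowered from 3 to 1, only the single best checkpoint by the monitored metric is kept, and any `lightning.modelcheckpoint` block from the YAML is layered on top of these defaults. A minimal sketch of that merge, assuming the OmegaConf-based merging main.py uses for its other Lightning configuration (the monitor key is illustrative):

```python
from omegaconf import OmegaConf

# Defaults assembled in main.py (monitor value is illustrative).
default_modelckpt_cfg = OmegaConf.create(
    {'params': {'monitor': 'val/loss_simple_ema', 'save_top_k': 1}}
)
# Override coming from lightning.modelcheckpoint in v1-finetune.yaml.
yaml_modelckpt_cfg = OmegaConf.create({'params': {'every_n_train_steps': 500}})

merged = OmegaConf.merge(default_modelckpt_cfg, yaml_modelckpt_cfg)
print(OmegaConf.to_yaml(merged))
# params:
#   monitor: val/loss_simple_ema
#   save_top_k: 1
#   every_n_train_steps: 500
```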
```diff
@@ -846,7 +862,7 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
         trainer_kwargs['callbacks'] = [
             instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
         ]
-        trainer_kwargs['max_steps'] = opt.max_steps
+        trainer_kwargs['max_steps'] = trainer_opt.max_steps

         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         trainer.logdir = logdir  ###
```
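This last hunk is what makes training stop on its own at the configured step count: `max_steps` is now read from `trainer_opt`, the namespace holding the Lightning Trainer arguments, instead of the project's `opt` namespace, so the value actually reaches the `Trainer`. A small sketch of the effect, with an illustrative step count:

```python
# Illustrative only: once max_steps reaches the Trainer, training halts
# automatically after that many optimizer steps.
from argparse import Namespace
from pytorch_lightning import Trainer

trainer_opt = Namespace(max_steps=6100)   # stand-in for the parsed Lightning args
trainer_kwargs = {'max_steps': trainer_opt.max_steps}

trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
print(trainer.max_steps)   # 6100
```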
