
Commit dd3a0a3

update main.py
1 parent cf94b6a commit dd3a0a3

File tree: 1 file changed (+104, -33 lines)


examples/autoencoderkl/train_autoencoderkl.py

Lines changed: 104 additions & 33 deletions
@@ -23,13 +23,16 @@
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from PIL import Image
-from taming.modules.losses.vqperceptual import *
+from taming.modules.losses.vqperceptual import (
+    hinge_d_loss, vanilla_d_loss, weights_init, NLayerDiscriminator
+)
 from torchvision import transforms
 from tqdm.auto import tqdm
 
 import diffusers
 from diffusers import AutoencoderKL
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
@@ -56,6 +59,7 @@ def image_grid(imgs, rows, cols):
     return grid
 
 
+@torch.no_grad()
 def log_validation(
     vae, args, accelerator, weight_dtype, step, is_final_validation=False
 ):
@@ -80,8 +84,8 @@ def log_validation(
 
     for i, validation_image in enumerate(args.validation_image):
         validation_image = Image.open(validation_image).convert("RGB")
-        targets = image_transforms(validation_image).to(weight_dtype)
-        targets = targets.unsqueeze(0).to(vae.device)
+        targets = image_transforms(validation_image).to(accelerator.device, weight_dtype)
+        targets = targets.unsqueeze(0)
 
         with inference_ctx:
             reconstructions = vae(targets).sample
@@ -112,15 +116,15 @@ def log_validation(
         gc.collect()
         torch.cuda.empty_cache()
 
-        return images
+    return images
 
 
 def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None):
     img_str = ""
     if images is not None:
         img_str = "You can find some example images below.\n\n"
-        image_grid(images, 1, "example").save(os.path.join(repo_folder, f"images_{i}.png"))
-        img_str += f"![images_{i})](./images_{i}.png)\n"
+        image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png"))
+        img_str += f"![images](./images.png)\n"
 
     model_description = f"""
 # autoencoderkl-{repo_id}
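Note: `image_grid` itself is not part of this hunk; in the diffusers example scripts it is conventionally defined along the lines below, which is why passing the string "example" as `cols` was broken and `len(images)` is the fix (a sketch of the helper, not the exact code from this file):

from PIL import Image

def image_grid(imgs, rows, cols):
    # Paste rows * cols PIL images into one grid image.
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid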
@@ -156,9 +160,14 @@ def parse_args(input_args=None):
         "--pretrained_model_name_or_path",
         type=str,
         default=None,
-        required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--model_config_name_or_path",
+        type=str,
+        default=None,
+        help="The config of the VAE model to train, leave as None to use standard VAE model configuration.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -242,6 +251,12 @@ def parse_args(input_args=None):
         default=4.5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--disc_learning_rate",
+        type=float,
+        default=4.5e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -257,6 +272,15 @@ def parse_args(input_args=None):
             ' "constant", "constant_with_warmup"]'
         ),
     )
+    parser.add_argument(
+        "--disc_lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
@@ -270,6 +294,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
+    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
     parser.add_argument(
         "--dataloader_num_workers",
         type=int,
@@ -417,7 +442,7 @@ def parse_args(input_args=None):
         help="Scaling factor for the Kullback-Leibler divergence penalty term.",
     )
     parser.add_argument(
-        "--lpips_scale",
+        "--perceptual_scale",
         type=float,
         default=0.5,
         help="Scaling factor for the LPIPS metric",
@@ -440,6 +465,12 @@ def parse_args(input_args=None):
         default=1.0,
         help="Scaling factor for the discriminator",
     )
+    parser.add_argument(
+        "--disc_loss",
+        type=str,
+        default="hinge",
+        help="Loss function for the discriminator",
+    )
     parser.add_argument(
         "--decoder_only",
         action="store_true",
@@ -587,19 +618,28 @@ def main(args):
         ).repo_id
 
     # Load AutoencoderKL
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path, revision=args.revision
-    )
-    lpips_loss_fn = lpips.LPIPS(net="vgg")
-    discriminator = NLayerDiscriminator(
-        input_nc=3, n_layers=3, use_actnorm=False,
-    ).apply(weights_init)
+    if args.pretrained_model_name_or_path is None and args.model_config_name_or_path is None:
+        config = AutoencoderKL.load_config("stabilityai/sd-vae-ft-mse")
+        vae = AutoencoderKL.from_config(config)
+    elif args.pretrained_model_name_or_path is not None:
+        vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, revision=args.revision)
+    else:
+        config = AutoencoderKL.load_config(args.model_config_name_or_path)
+        vae = AutoencoderKL.from_config(config)
+    if args.use_ema:
+        ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
+    perceptual_loss = lpips.LPIPS(net="vgg").eval()
+    discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
 
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
             if accelerator.is_main_process:
+                if args.use_ema:
+                    sub_dir = "autoencoderkl_ema"
+                    ema_vae.save_pretrained(os.path.join(output_dir, sub_dir))
+
                 i = len(weights) - 1
 
                 while len(weights) > 0:
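Note: the branching above distinguishes training from scratch (random weights built from a config) from fine-tuning (pretrained weights). A minimal sketch of the two AutoencoderKL entry points it relies on:

from diffusers import AutoencoderKL

# Architecture only: read a config and build a randomly initialized VAE.
config = AutoencoderKL.load_config("stabilityai/sd-vae-ft-mse")
vae_from_scratch = AutoencoderKL.from_config(config)

# Architecture + weights: load a pretrained VAE to fine-tune.
vae_pretrained = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")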
@@ -618,13 +658,22 @@ def save_model_hook(models, weights, output_dir):
 
         def load_model_hook(models, input_dir):
             while len(models) > 0:
+                if args.use_ema:
+                    sub_dir = "autoencoderkl_ema"
+                    load_model = EMAModel.from_pretrained(os.path.join(input_dir, sub_dir), AutoencoderKL)
+                    ema_vae.load_state_dict(load_model.state_dict())
+                    ema_vae.to(accelerator.device)
+                    del load_model
+
                 # pop models so that they are not loaded again
                 model = models.pop()
-
-                # load diffusers style into model
+                load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).load_state_dict(os.path.join(input_dir, "discriminator", "pytorch_model.bin"))
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+                model = models.pop()
                 load_model = AutoencoderKL.from_pretrained(input_dir, subfolder="autoencoderkl")
                 model.register_to_config(**load_model.config)
-
                 model.load_state_dict(load_model.state_dict())
                 del load_model
 
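Note: in the hunk above, `load_state_dict` is handed a file path, and its return value (a keys report, not a module) is then treated as a model, so the following `load_model.state_dict()` would fail at runtime. A working variant of that block would look roughly like this (a sketch, not the committed code, assuming the popped `model` here is the discriminator as the hunk implies):

import os
import torch

# Inside load_model_hook: `input_dir` is the checkpoint directory.
load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False)
state_dict = torch.load(os.path.join(input_dir, "discriminator", "pytorch_model.bin"))
load_model.load_state_dict(state_dict)
model.load_state_dict(load_model.state_dict())
del load_model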
@@ -638,7 +687,6 @@ def load_model_hook(models, input_dir):
     if getattr(vae, "quant_conv", None):
         vae.quant_conv.requires_grad_(False)
     vae.train()
-    lpips_loss_fn.requires_grad_(False)
     discriminator.requires_grad_(True)
     discriminator.train()
 
@@ -688,17 +736,17 @@ def load_model_hook(models, input_dir):
         optimizer_class = torch.optim.AdamW
 
     params_to_optimize = filter(lambda p: p.requires_grad, vae.parameters())
-    params_to_optimize_2 = filter(lambda p: p.requires_grad, discriminator.parameters())
+    disc_params_to_optimize = filter(lambda p: p.requires_grad, discriminator.parameters())
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
     )
-    optimizer_2 = optimizer_class(
-        params_to_optimize_2,
-        lr=args.learning_rate,
+    disc_optimizer = optimizer_class(
+        disc_params_to_optimize,
+        lr=args.disc_learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
@@ -729,10 +777,18 @@ def load_model_hook(models, input_dir):
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
+    disc_lr_scheduler = get_scheduler(
+        args.disc_lr_scheduler,
+        optimizer=disc_optimizer,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_cycles=args.lr_num_cycles,
+        power=args.lr_power,
+    )
 
     # Prepare everything with our `accelerator`.
-    vae, discriminator, optimizer, optimizer_2, train_dataloader, lr_scheduler = accelerator.prepare(
-        vae, discriminator, optimizer, optimizer_2, train_dataloader, lr_scheduler
+    vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler = accelerator.prepare(
+        vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler
     )
 
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
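Note: with two optimizer/scheduler pairs prepared, the training loop steps them separately, generator-side and discriminator-side. A condensed sketch of that pattern using the names from this script (the step/zero_grad calls are the standard accelerate idiom, not lines copied from this commit):

# Generator (VAE) update
with accelerator.accumulate(vae):
    loss = ...  # reconstruction + KL + adversarial terms
    accelerator.backward(loss)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

# Discriminator update
with accelerator.accumulate(discriminator):
    disc_loss = ...  # hinge or vanilla loss on real/fake logits
    accelerator.backward(disc_loss)
    disc_optimizer.step()
    disc_lr_scheduler.step()
    disc_optimizer.zero_grad()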
@@ -743,10 +799,12 @@ def load_model_hook(models, input_dir):
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    # Move vae to device and cast to weight_dtype
+    # Move VAE, perceptual loss and discriminator to device and cast to weight_dtype
     vae.to(accelerator.device, dtype=weight_dtype)
-    lpips_loss_fn.to(accelerator.device, dtype=weight_dtype)
+    perceptual_loss.to(accelerator.device, dtype=weight_dtype)
     discriminator.to(accelerator.device, dtype=weight_dtype)
+    if args.use_ema:
+        ema_vae.to(accelerator.device, dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -812,6 +870,8 @@ def load_model_hook(models, input_dir):
 
     image_logs = None
     for epoch in range(first_epoch, args.num_train_epochs):
+        vae.train()
+        discriminator.train()
         for step, batch in enumerate(train_dataloader):
             # Convert images to latent space and reconstruct from them
             targets = batch["pixel_values"].to(dtype=weight_dtype)
@@ -834,9 +894,9 @@ def load_model_hook(models, input_dir):
                 rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction="none")
                 # perceptual loss. The high level feature mean squared error loss
                 with torch.no_grad():
-                    lpips_loss = lpips_loss_fn(reconstructions, targets)
+                    p_loss = perceptual_loss(reconstructions, targets)
 
-                rec_loss = rec_loss + args.lpips_scale * lpips_loss
+                rec_loss = rec_loss + args.perceptual_scale * p_loss
                 nll_loss = rec_loss
                 nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
 
@@ -859,10 +919,10 @@ def load_model_hook(models, input_dir):
                     "loss": loss.detach().mean().item(),
                     "nll_loss": nll_loss.detach().mean().item(),
                     "rec_loss": rec_loss.detach().mean().item(),
-                    "lpips_loss": lpips_loss.detach().mean().item(),
+                    "p_loss": p_loss.detach().mean().item(),
                     "kl_loss": kl_loss.detach().mean().item(),
                     "disc_weight": disc_weight.detach().mean().item(),
-                    "disc_factor": torch.tensor(disc_factor),
+                    "disc_factor": disc_factor,
                     "g_loss": g_loss.detach().mean().item(),
                     "lr": lr_scheduler.get_last_lr()[0]
                 }
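Note: the keys logged above correspond to the pieces of the taming-transformers style VAE-GAN objective. Roughly, the generator-side loss combines them as sketched below; the KL flag name (`kl_scale`) is not visible in this diff and is assumed here, as is the non-saturating form of `g_loss`:

# Sketch of how the logged terms typically combine (not a verbatim excerpt from this script).
rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction="none")
rec_loss = rec_loss + args.perceptual_scale * p_loss  # L1 + scaled LPIPS term
nll_loss = torch.sum(rec_loss) / rec_loss.shape[0]

g_loss = -torch.mean(logits_fake)  # generator wants high logits on reconstructions
disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0

# `args.kl_scale` is an assumed flag name; only its help text appears earlier in this diff.
loss = nll_loss + args.kl_scale * kl_loss + disc_weight * disc_factor * g_loss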
@@ -878,18 +938,21 @@ def load_model_hook(models, input_dir):
             with accelerator.accumulate(discriminator):
                 logits_real = discriminator(targets)
                 logits_fake = discriminator(reconstructions)
-                disc_loss = hinge_d_loss
+                disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
                 disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
                 disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
                 logs = {
                     "disc_loss": disc_loss.detach().mean().item(),
                     "logits_real": logits_real.detach().mean().item(),
                     "logits_fake": logits_fake.detach().mean().item(),
+                    "disc_lr": disc_lr_scheduler.get_last_lr()[0]
                 }
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1
+                if args.use_ema:
+                    ema_vae.step(vae.parameters())
 
             if accelerator.is_main_process:
                 if global_step % args.checkpointing_steps == 0:
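Note: `hinge_d_loss` and `vanilla_d_loss` are the taming-transformers discriminator losses imported at the top of the file; up to formatting they are defined as:

import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # Hinge loss: penalize real logits below +1 and fake logits above -1.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

def vanilla_d_loss(logits_real, logits_fake):
    # Non-saturating BCE-style loss written with softplus.
    return 0.5 * (torch.mean(F.softplus(-logits_real)) + torch.mean(F.softplus(logits_fake)))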
@@ -918,13 +981,18 @@ def load_model_hook(models, input_dir):
                     logger.info(f"Saved state to {save_path}")
 
                 if global_step == 1 or global_step % args.validation_steps == 0:
+                    if args.use_ema:
+                        ema_vae.store(vae.parameters())
+                        ema_vae.copy_to(vae.parameters())
                     image_logs = log_validation(
                         vae,
                         args,
                         accelerator,
                         weight_dtype,
                         global_step,
                     )
+                    if args.use_ema:
+                        ema_vae.restore(vae.parameters())
 
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
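Note: the EMA calls introduced throughout this commit follow the usual `diffusers.training_utils.EMAModel` lifecycle. Pulled together, the pattern is roughly (a summary sketch, not a contiguous excerpt; `vae` and `output_dir` are assumed to be in scope):

import os

from diffusers import AutoencoderKL
from diffusers.training_utils import EMAModel

ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)

# After each optimizer step, update the shadow (EMA) weights.
ema_vae.step(vae.parameters())

# For validation, swap the EMA weights in, evaluate, then restore the training weights.
ema_vae.store(vae.parameters())
ema_vae.copy_to(vae.parameters())
# ... run log_validation(...) ...
ema_vae.restore(vae.parameters())

# Checkpointing: EMAModel serializes and restores on its own.
ema_vae.save_pretrained(os.path.join(output_dir, "autoencoderkl_ema"))
ema_vae = EMAModel.from_pretrained(os.path.join(output_dir, "autoencoderkl_ema"), AutoencoderKL)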
@@ -936,8 +1004,11 @@ def load_model_hook(models, input_dir):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         vae = accelerator.unwrap_model(vae)
+        discriminator = accelerator.unwrap_model(discriminator)
+        if args.use_ema:
+            ema_vae.copy_to(vae.parameters())
         vae.save_pretrained(args.output_dir)
-
+        torch.save(discriminator.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin"))
         # Run a final round of validation.
         image_logs = None
         image_logs = log_validation(
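Note: after the final save above, the output directory holds the VAE in diffusers format (with EMA weights copied in when --use_ema is set) plus the raw discriminator state dict. Reloading both later would look roughly like this (a usage sketch; `output_dir` stands in for args.output_dir):

import os
import torch

from diffusers import AutoencoderKL
from taming.modules.losses.vqperceptual import NLayerDiscriminator

output_dir = "path/to/output_dir"  # placeholder
vae = AutoencoderKL.from_pretrained(output_dir)
discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False)
discriminator.load_state_dict(torch.load(os.path.join(output_dir, "pytorch_model.bin")))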
