Commit ca62ba8

update autoencoderkl.py

1 parent ceec5f3 · commit ca62ba8

1 file changed: +49 −41 lines changed

examples/research_projects/autoencoderkl/train_autoencoderkl.py

Lines changed: 49 additions & 41 deletions
@@ -19,13 +19,12 @@
 import logging
 import math
 import os
-import random
 import shutil
 from pathlib import Path
 
 import accelerate
-import numpy as np
 import lpips
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -38,9 +37,7 @@
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from PIL import Image
-from taming.modules.losses.vqperceptual import (
-    hinge_d_loss, vanilla_d_loss, weights_init, NLayerDiscriminator
-)
+from taming.modules.losses.vqperceptual import NLayerDiscriminator, hinge_d_loss, vanilla_d_loss, weights_init
 from torchvision import transforms
 from tqdm.auto import tqdm
 
@@ -93,22 +90,22 @@ def log_validation(
 
         with inference_ctx:
             reconstructions = vae(targets).sample
-
+
         images.append(
             torch.cat([targets.cpu(), reconstructions.cpu()], axis=0)
         )
-
+
     tracker_key = "test" if is_final_validation else "validation"
     for tracker in accelerator.trackers:
         if tracker.name == "tensorboard":
             np_images = np.stack([np.asarray(img) for img in images])
             tracker.writer.add_images(
-                "Original (left), Reconstruction (right)", np_images, step
+                f"{tracker_key}: Original (left), Reconstruction (right)", np_images, step
             )
         elif tracker.name == "wandb":
             tracker.log(
                 {
-                    "Original (left), Reconstruction (right)": [
+                    f"{tracker_key}: Original (left), Reconstruction (right)": [
                         wandb.Image(torchvision.utils.make_grid(image))
                         for _, image in enumerate(images)
                     ]
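
The `tracker_key` prefix added here separates intermediate validation logs from the final test-set logs. A minimal sketch of the two keys this produces (illustrative, not part of the commit):

for is_final_validation in (False, True):
    tracker_key = "test" if is_final_validation else "validation"
    print(f"{tracker_key}: Original (left), Reconstruction (right)")
# validation: Original (left), Reconstruction (right)
# test: Original (left), Reconstruction (right)
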
@@ -127,8 +124,8 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None)
     img_str = ""
     if images is not None:
         img_str = "You can find some example images below.\n\n"
-        make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png"))
-        img_str += f"![images](./images.png)\n"
+        make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, "images.png"))
+        img_str += "![images](./images.png)\n"
 
     model_description = f"""
 # autoencoderkl-{repo_id}
@@ -529,7 +526,7 @@ def make_train_dataset(args, accelerator):
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
     column_names = dataset["train"].column_names
-
+
     # 6. Get the column names for input/target.
     if args.image_column is None:
         image_column = column_names[0]
@@ -540,7 +537,7 @@ def make_train_dataset(args, accelerator):
         raise ValueError(
             f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
         )
-
+
     image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
@@ -580,7 +577,7 @@ def main(args):
             "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
             " Please use `huggingface-cli login` to authenticate with the Hub."
         )
-
+
     logging_dir = Path(args.output_dir, args.logging_dir)
 
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -591,7 +588,7 @@ def main(args):
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
-
+
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
@@ -623,7 +620,7 @@ def main(args):
         repo_id = create_repo(
             repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
         ).repo_id
-
+
     # Load AutoencoderKL
     if args.pretrained_model_name_or_path is None and args.model_config_name_or_path is None:
         config = AutoencoderKL.load_config("stabilityai/sd-vae-ft-mse")
@@ -637,7 +634,13 @@ def main(args):
         ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
     perceptual_loss = lpips.LPIPS(net="vgg").eval()
     discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
-
+
+    # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
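
The new `unwrap_model` helper handles modules wrapped both by `accelerate` (e.g. DDP) and by `torch.compile`, which keeps the original module at `_orig_mod`. A self-contained sketch of the `torch.compile` case, using a hypothetical toy module (illustrative, not part of the commit):

import torch

class TinyNet(torch.nn.Module):
    def forward(self, x):
        return x * 2

net = TinyNet()
compiled = torch.compile(net)  # wraps net in an OptimizedModule

# Saving hooks need the underlying module; torch.compile exposes it
# as `_orig_mod`, which is what `is_compiled_module` checks for.
assert compiled._orig_mod is net
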
@@ -677,7 +680,7 @@ def load_model_hook(models, input_dir):
            load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).load_state_dict(os.path.join(input_dir, "discriminator", "pytorch_model.bin"))
            model.load_state_dict(load_model.state_dict())
            del load_model
-
+
            model = models.pop()
            load_model = AutoencoderKL.from_pretrained(input_dir, subfolder="autoencoderkl")
            model.register_to_config(**load_model.config)
@@ -686,8 +689,8 @@ def load_model_hook(models, input_dir):
 
        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)
-
-
+
+
    vae.requires_grad_(True)
    if args.decoder_only:
        vae.encoder.requires_grad_(False)
@@ -696,7 +699,7 @@ def load_model_hook(models, input_dir):
    vae.train()
    discriminator.requires_grad_(True)
    discriminator.train()
-
+
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers
@@ -709,16 +712,21 @@ def load_model_hook(models, input_dir):
            vae.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
-
+
    if args.gradient_checkpointing:
        vae.enable_gradient_checkpointing()
-
+
    # Check that all trainable models are in full precision
    low_precision_error_string = (
        " Please make sure to always have all model weights in full float32 precision when starting training - even if"
        " doing mixed precision training, copy of the weights should still be float32."
    )
-
+
+    if unwrap_model(vae).dtype != torch.float32:
+        raise ValueError(
+            f"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}"
+        )
+
    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
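
The added guard enforces that the trainable VAE stays in float32: even under mixed-precision training, optimizer updates applied to half-precision master weights can underflow. A quick self-contained illustration with assumed values (not part of the commit):

import torch

w = torch.tensor(1.0, dtype=torch.float16)
step = torch.tensor(1e-4, dtype=torch.float16)
print(w + step)                  # tensor(1., dtype=torch.float16) - the update vanishes
print(w.float() + step.float())  # tensor(1.0001) - float32 keeps it
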
@@ -728,7 +736,7 @@ def load_model_hook(models, input_dir):
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )
-
+
    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
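
The scaling shown in this hunk's context (typically gated by a `--scale_lr` flag in the diffusers training scripts) is the usual linear scaling rule: the base learning rate is multiplied by the effective batch size. A worked example with hypothetical values:

base_lr = 1e-4
gradient_accumulation_steps = 4
train_batch_size = 8
num_processes = 2  # e.g. two GPUs

# effective batch = 4 * 8 * 2 = 64 samples per optimizer step
scaled_lr = base_lr * gradient_accumulation_steps * train_batch_size * num_processes
print(scaled_lr)  # 0.0064
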
@@ -741,7 +749,7 @@ def load_model_hook(models, input_dir):
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW
-
+
    params_to_optimize = filter(lambda p: p.requires_grad, vae.parameters())
    disc_params_to_optimize = filter(lambda p: p.requires_grad, discriminator.parameters())
    optimizer = optimizer_class(
@@ -760,22 +768,22 @@ def load_model_hook(models, input_dir):
    )
 
    train_dataset = make_train_dataset(args, accelerator)
-
+
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
-
+
    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
-
+
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
@@ -792,27 +800,27 @@ def load_model_hook(models, input_dir):
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )
-
+
    # Prepare everything with our `accelerator`.
    vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler = accelerator.prepare(
        vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler
    )
-
+
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
-
+
    # Move VAE, perceptual loss and discriminator to device and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    perceptual_loss.to(accelerator.device, dtype=weight_dtype)
    discriminator.to(accelerator.device, dtype=weight_dtype)
    if args.use_ema:
        ema_vae.to(accelerator.device, dtype=weight_dtype)
-
+
    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
@@ -850,7 +858,7 @@ def load_model_hook(models, input_dir):
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None
-
+
        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
@@ -866,7 +874,7 @@ def load_model_hook(models, input_dir):
            first_epoch = global_step // num_update_steps_per_epoch
        else:
            initial_global_step = 0
-
+
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
@@ -898,7 +906,7 @@ def load_model_hook(models, input_dir):
                # perceptual loss. The high level feature mean squared error loss
                with torch.no_grad():
                    p_loss = perceptual_loss(reconstructions, targets)
-
+
                rec_loss = rec_loss + args.perceptual_scale * p_loss
                nll_loss = rec_loss
                nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
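
`perceptual_loss` here is the LPIPS metric imported at the top of the file. A self-contained sketch of the call pattern (LPIPS expects inputs scaled to [-1, 1]; the shapes and random tensors are illustrative only):

import lpips
import torch

loss_fn = lpips.LPIPS(net="vgg").eval()
a = torch.rand(1, 3, 64, 64) * 2 - 1  # stands in for "reconstructions"
b = torch.rand(1, 3, 64, 64) * 2 - 1  # stands in for "targets"
with torch.no_grad():
    print(loss_fn(a, b).shape)  # torch.Size([1, 1, 1, 1]) - one distance per image
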
@@ -915,9 +923,9 @@ def load_model_hook(models, input_dir):
                disc_weight = torch.clamp(disc_weight, 0.0, 1e4).detach()
                disc_weight = disc_weight * args.disc_scale
                disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
-
+
                loss = nll_loss + args.kl_scale * kl_loss + disc_weight * disc_factor * g_loss
-
+
                logs = {
                    "loss": loss.detach().mean().item(),
                    "nll_loss": nll_loss.detach().mean().item(),
@@ -929,7 +937,7 @@ def load_model_hook(models, input_dir):
                    "g_loss": g_loss.detach().mean().item(),
                    "lr": lr_scheduler.get_last_lr()[0]
                }
-
+
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = vae.parameters()
@@ -1002,7 +1010,7 @@ def load_model_hook(models, input_dir):
 
            if global_step >= args.max_train_steps:
                break
-
+
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
@@ -1036,7 +1044,7 @@ def load_model_hook(models, input_dir):
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )
-
+
    accelerator.end_training()
 
 