Commit cdee155

update train_autoencoderkl.py
1 parent dd3a0a3 commit cdee155

File tree

4 files changed: +51, -26 lines changed

examples/autoencoderkl/README.md renamed to examples/research_projects/autoencoderkl/README.md

Lines changed: 17 additions & 2 deletions
@@ -25,11 +25,26 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e
 accelerate config
 ```

+## Training on CIFAR10
+
+```bash
+accelerate launch train_autoencoderkl.py \
+  --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
+  --dataset_name=cifar10 \
+  --image_column=img \
+  --validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
+  --num_train_epochs 100 \
+  --gradient_accumulation_steps 2 \
+  --learning_rate 4.5e-6 \
+  --lr_scheduler cosine \
+  --report_to wandb \
+```
+
 ## Training on ImageNet

 ```bash
-accelerate launch --multi_gpu --num_processes 4 --mixed_precision bf16 train_autoencoderkl.py \
-  --pretrained_model_name_or_path stabilityai/sdxl-vae \
+accelerate launch train_autoencoderkl.py \
+  --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
   --num_train_epochs 100 \
   --gradient_accumulation_steps 2 \
   --learning_rate 4.5e-6 \
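The new CIFAR10 example passes `--image_column=img` because the Hugging Face `cifar10` dataset stores its images under the `img` column. A quick standalone check (a sketch, not part of the commit, assuming the `datasets` library is installed):

```python
# Sketch: confirm the column name the CIFAR10 example feeds to --image_column.
from datasets import load_dataset

ds = load_dataset("cifar10", split="train[:10]")
print(ds.column_names)       # expected to include "img" and "label"
ds[0]["img"].save("sample.png")  # each "img" entry decodes to a PIL image
```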

examples/autoencoderkl/train_autoencoderkl.py renamed to examples/research_projects/autoencoderkl/train_autoencoderkl.py

Lines changed: 24 additions & 24 deletions
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import argparse
 import contextlib
 import gc
@@ -33,7 +48,7 @@
 from diffusers import AutoencoderKL
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -43,22 +58,11 @@
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.30.0.dev0")
+# check_min_version("0.33.0.dev0")

 logger = get_logger(__name__)


-def image_grid(imgs, rows, cols):
-    assert len(imgs) == rows * cols
-
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
-
-
 @torch.no_grad()
 def log_validation(
     vae, args, accelerator, weight_dtype, step, is_final_validation=False
@@ -111,7 +115,7 @@ def log_validation(
                 }
             )
         else:
-            logger.warn(f"image logging not implemented for {tracker.gen_images}")
+            logger.warn(f"image logging not implemented for {tracker.name}")

     gc.collect()
     torch.cuda.empty_cache()
@@ -123,7 +127,7 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None)
     img_str = ""
     if images is not None:
         img_str = "You can find some example images below.\n\n"
-        image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png"))
+        make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png"))
         img_str += f"![images](./images.png)\n"

     model_description = f"""
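For reference, `make_image_grid` from `diffusers.utils` covers what the removed local `image_grid` helper did. A minimal standalone sketch (not part of the commit; the dummy images are placeholders):

```python
# Sketch: diffusers.utils.make_image_grid stitches PIL images into a rows x cols grid.
from PIL import Image
from diffusers.utils import make_image_grid

# Three dummy 64x64 tiles stand in for validation reconstructions.
images = [Image.new("RGB", (64, 64), color=c) for c in ("red", "green", "blue")]
make_image_grid(images, rows=1, cols=len(images)).save("images.png")
```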
@@ -875,23 +879,19 @@ def load_model_hook(models, input_dir):
         for step, batch in enumerate(train_dataloader):
             # Convert images to latent space and reconstruct from them
             targets = batch["pixel_values"].to(dtype=weight_dtype)
-            if accelerator.num_processes > 1:
-                posterior = vae.module.encode(targets).latent_dist
-            else:
-                posterior = vae.encode(targets).latent_dist
+            posterior = accelerator.unwrap_model(vae).encode(targets).latent_dist
             latents = posterior.sample()
-            if accelerator.num_processes > 1:
-                reconstructions = vae.module.decode(latents).sample
-            else:
-                reconstructions = vae.decode(latents).sample
+            reconstructions = accelerator.unwrap_model(vae).decode(latents).sample

             if (step // args.gradient_accumulation_steps) % 2 == 0 or global_step < args.disc_start:
                 with accelerator.accumulate(vae):
                     # reconstruction loss. Pixel level differences between input vs output
                     if args.rec_loss == "l2":
                         rec_loss = F.mse_loss(reconstructions.float(), targets.float(), reduction="none")
-                    else:
+                    elif args.rec_loss == "l1":
                         rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction="none")
+                    else:
+                        raise ValueError(f"Invalid reconstruction loss type: {args.rec_loss}")
                     # perceptual loss. The high level feature mean squared error loss
                     with torch.no_grad():
                         p_loss = perceptual_loss(reconstructions, targets)
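The branching on `accelerator.num_processes` can be dropped because `Accelerator.unwrap_model` returns the underlying module whether or not it has been wrapped (e.g. in `DistributedDataParallel`), so a single code path covers single- and multi-GPU runs. A minimal standalone sketch (not part of the commit; the `Linear` model is a placeholder):

```python
# Sketch: one code path for single- and multi-GPU runs via unwrap_model.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))

# Equivalent to `model.module` under DDP and to `model` in a single-process run.
unwrapped = accelerator.unwrap_model(model)
print(type(unwrapped))  # <class 'torch.nn.modules.linear.Linear'>
```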
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+accelerate launch train_autoencoderkl.py \
+  --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
+  --dataset_name=cifar10 \
+  --image_column=img \
+  --validation_image /home/azureuser/v-yuqianhong/ImageNet/ILSVRC2012/val/n01491361/ILSVRC2012_val_00002922.JPEG \
+  --num_train_epochs 100 \
+  --gradient_accumulation_steps 2 \
+  --learning_rate 4.5e-6 \
+  --lr_scheduler cosine \
+  --report_to wandb \
