
Commit 6683f97

[Training] Add datasets version of LCM LoRA SDXL (#5778)
* add: script to train lcm lora for sdxl with 🤗 datasets
* suit up the args.
* remove comments.
* fix num_update_steps
* fix batch unmarshalling
* fix num_update_steps_per_epoch
* fix; dataloading.
* fix microconditions.
* unconditional predictions debug
* fix batch size.
* no need to use use_auth_token
* Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]>
* make vae encoding batch size an arg
* final serialization in kohya
* style
* state dict rejigging
* feat: no separate teacher unet.
* debug
* fix state dict serialization
* debug
* debug
* debug
* remove prints.
* remove kohya utility and make style
* fix serialization
* fix
* add test
* add peft dependency.
* add: peft
* remove peft
* autocast device determination from accelerator
* autocast
* reduce lora rank.
* remove unneeded space
* Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]>
* style
* remove prompt dropout.
* also save in native diffusers ckpt format.
* debug
* debug
* debug
* better formation of the null embeddings.
* remove space.
* autocast fixes.
* autocast fix.
* hacky
* remove lora_sayak
* Apply suggestions from code review Co-authored-by: Younes Belkada <[email protected]>
* style
* make log validation leaner.
* move back enabled in.
* fix: log_validation call.
* add: checkpointing tests
* taking my chances to see if disabling autocasting has any effect?
* start debugging
* name
* name
* name
* more debug
* more debug
* index
* remove index.
* print length
* print length
* print length
* move unet.train() after add_adapter()
* disable some prints.
* enable_adapters() manually.
* remove prints.
* some changes.
* fix params_to_optimize
* more fixes
* debug
* debug
* remove print
* disable grad for certain contexts.
* Add support for IPAdapterFull (#5911)
* Add support for IPAdapterFull Co-authored-by: Patrick von Platen <[email protected]> --------- Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
* Fix a bug in `add_noise` function (#6085)
* fix
* copies --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
* [Advanced Diffusion Script] Add Widget default text (#6100) add widget
* [Advanced Training Script] Fix pipe example (#6106)
* IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (#5901)
* adapter for StableDiffusionControlNetImg2ImgPipeline
* fix-copies
* fix-copies --------- Co-authored-by: Sayak Paul <[email protected]>
* IP adapter support for most pipelines (#5900)
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
* update tests
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
* revert changes to sd_attend_and_excite and sd_upscale
* make style
* fix broken tests
* update ip-adapter implementation to latest
* apply suggestions from review --------- Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
* fix: lora_alpha
* make vae casting conditional/
* param upcasting
* propagate comments from #6145 Co-authored-by: dg845 <[email protected]>
* [Peft] fix saving / loading when unet is not "unet" (#6046)
* [Peft] fix saving / loading when unet is not "unet"
* Update src/diffusers/loaders/lora.py Co-authored-by: Sayak Paul <[email protected]>
* undo stablediffusion-xl changes
* use unet_name to get unet for lora helpers
* use unet_name --------- Co-authored-by: Sayak Paul <[email protected]>
* [Wuerstchen] fix fp16 training and correct lora args (#6245) fix fp16 training Co-authored-by: Sayak Paul <[email protected]>
* [docs] fix: animatediff docs (#6339) fix: animatediff docs
* add: note about the new script in readme_sdxl.
* Revert "[Peft] fix saving / loading when unet is not "unet" (#6046)" This reverts commit 4c7e983.
* Revert "[Wuerstchen] fix fp16 training and correct lora args (#6245)" This reverts commit 0bb9cf0.
* Revert "[docs] fix: animatediff docs (#6339)" This reverts commit 11659a6.
* remove tokenize_prompt().
* assistive comments around enable_adapters() and diable_adapters().

---------

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Fabio Rigano <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: apolinário <[email protected]>
Co-authored-by: Charchit Sharma <[email protected]>
Co-authored-by: Aryan V S <[email protected]>
Co-authored-by: dg845 <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
Parent: 4e7b0cb · Commit: 6683f97

4 files changed: +1507 / -1 lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -161,6 +161,8 @@ def save_model_card(
 base_model: {base_model}
 instance_prompt: {instance_prompt}
 license: openrail++
+widget:
+- text: '{validation_prompt if validation_prompt else instance_prompt}'
 ---
 """
```
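The new `widget` entry makes the generated model card front matter carry the validation prompt, falling back to the instance prompt when no validation prompt is set. A small illustrative sketch of how that f-string fragment renders; the values below are made up for the example and are not defaults from the script:

```python
# Illustrative values only; not defaults from the training script.
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
instance_prompt = "a photo of sks dog"
validation_prompt = None  # when unset, the widget text falls back to the instance prompt

front_matter = f"""
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
widget:
- text: '{validation_prompt if validation_prompt else instance_prompt}'
---
"""
print(front_matter)
```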

examples/consistency_distillation/README_sdxl.md

Lines changed: 35 additions & 1 deletion
````diff
@@ -111,4 +111,38 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
   --report_to=wandb \
   --seed=453645634 \
   --push_to_hub \
-```
+```
+
+We provide another version for LCM LoRA SDXL that follows best practices of `peft` and leverages the `datasets` library for quick experimentation. Unlike `train_lcm_distill_lora_sdxl_wds.py`, this script doesn't load two UNets, which reduces the memory requirements quite a bit.
+
+Below is an example training command that trains an LCM LoRA on the [Pokémon dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions):
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
+
+accelerate launch train_lcm_distill_lora_sdxl.py \
+  --pretrained_teacher_model=${MODEL_NAME} \
+  --pretrained_vae_model_name_or_path=${VAE_PATH} \
+  --output_dir="pokemons-lora-lcm-sdxl" \
+  --mixed_precision="fp16" \
+  --dataset_name=$DATASET_NAME \
+  --resolution=1024 \
+  --train_batch_size=24 \
+  --gradient_accumulation_steps=1 \
+  --gradient_checkpointing \
+  --use_8bit_adam \
+  --lora_rank=64 \
+  --learning_rate=1e-4 \
+  --report_to="wandb" \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=3000 \
+  --checkpointing_steps=500 \
+  --validation_steps=50 \
+  --seed="0" \
+  --push_to_hub
+```
````
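Once training finishes, the distilled LoRA can be loaded for few-step inference with the usual diffusers LoRA and LCM scheduler APIs. A minimal sketch, assuming the weights were pushed to a placeholder repository `your-username/pokemons-lora-lcm-sdxl` (not something the command above creates automatically):

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

# Load the SDXL base model and swap in the LCM scheduler for few-step sampling.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    torch_dtype=torch.float16,
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Placeholder repository id; point this at wherever the training run pushed the LoRA.
pipe.load_lora_weights("your-username/pokemons-lora-lcm-sdxl")

# LCM LoRAs are typically sampled with very few steps and guidance_scale close to 1.
image = pipe(
    "a green pokemon with menacing eyes",
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]
image.save("pokemon.png")
```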
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
```python
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class TextToImageLCM(ExamplesTestsAccelerate):
    def test_text_to_image_lcm_lora_sdxl(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

    def test_text_to_image_lcm_lora_sdxl_checkpointing(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 7
                --checkpointing_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4", "checkpoint-6"},
            )

            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 9
                --checkpointing_steps 2
                --resume_from_checkpoint latest
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
            )
```
