Skip to content

Commit 65e3090

Browse files
linoytsabanbghirasayakpaul
authored
[Flux] Dreambooth LoRA training scripts (#9086)
* initial commit - dreambooth for flux * update transformer to be FluxTransformer2DModel * update training loop and validation inference * fix sd3->flux docs * add guidance handling, not sure if it makes sense(?) * inital dreambooth lora commit * fix text_ids in compute_text_embeddings * fix imports of static methods * fix pipeline loading in readme, remove auto1111 docs for now * fix pipeline loading in readme, remove auto1111 docs for now, remove some irrelevant text_encoder_3 refs * Update examples/dreambooth/train_dreambooth_flux.py Co-authored-by: Bagheera <[email protected]> * fix te2 loading and remove te2 refs from text encoder training * fix tokenizer_2 initialization * remove text_encoder training refs from lora script (for now) * try with vae in bfloat16, fix model hook save * fix tokenization * fix static imports * fix CLIP import * remove text_encoder training refs (for now) from lora script * fix minor bug in encode_prompt, add guidance def in lora script, ... * fix unpack_latents args * fix license in readme * add "none" to weighting_scheme options for uniform sampling * style * adapt model saving - remove text encoder refs * adapt model loading - remove text encoder refs * initial commit for readme * Update examples/dreambooth/train_dreambooth_lora_flux.py Co-authored-by: Sayak Paul <[email protected]> * Update examples/dreambooth/train_dreambooth_lora_flux.py Co-authored-by: Sayak Paul <[email protected]> * fix vae casting * remove precondition_outputs * readme * readme * style * readme * readme * update weighting scheme default & docs * style * add text_encoder training to lora script, change vae_scale_factor value in both * style * text encoder training fixes * style * update readme * minor fixes * fix te params * fix te params --------- Co-authored-by: Bagheera <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent cee7c1b commit 65e3090

File tree

3 files changed

+3806
-0
lines changed

3 files changed

+3806
-0
lines changed

examples/dreambooth/README_flux.md

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# DreamBooth training example for FLUX.1 [dev]
2+
3+
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
4+
5+
The `train_dreambooth_flux.py` script shows how to implement the training procedure and adapt it for [FLUX.1 [dev]](https://blackforestlabs.ai/announcing-black-forest-labs/). We also provide a LoRA implementation in the `train_dreambooth_lora_flux.py` script.
6+
> [!NOTE]
7+
> **Memory consumption**
8+
>
9+
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
10+
> a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
11+
> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](documentation/quickstart/FLUX.md)
12+
13+
14+
> [!NOTE]
15+
> **Gated model**
16+
>
17+
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
18+
19+
```bash
20+
huggingface-cli login
21+
```
22+
23+
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
24+
25+
## Running locally with PyTorch
26+
27+
### Installing the dependencies
28+
29+
Before running the scripts, make sure to install the library's training dependencies:
30+
31+
**Important**
32+
33+
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
34+
35+
```bash
36+
git clone https://github.com/huggingface/diffusers
37+
cd diffusers
38+
pip install -e .
39+
```
40+
41+
Then cd in the `examples/dreambooth` folder and run
42+
```bash
43+
pip install -r requirements_flux.txt
44+
```
45+
46+
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
47+
48+
```bash
49+
accelerate config
50+
```
51+
52+
Or for a default accelerate configuration without answering questions about your environment
53+
54+
```bash
55+
accelerate config default
56+
```
57+
58+
Or if your environment doesn't support an interactive shell (e.g., a notebook)
59+
60+
```python
61+
from accelerate.utils import write_basic_config
62+
write_basic_config()
63+
```
64+
65+
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
66+
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
67+
68+
69+
### Dog toy example
70+
71+
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
72+
73+
Let's first download it locally:
74+
75+
```python
76+
from huggingface_hub import snapshot_download
77+
78+
local_dir = "./dog"
79+
snapshot_download(
80+
"diffusers/dog-example",
81+
local_dir=local_dir, repo_type="dataset",
82+
ignore_patterns=".gitattributes",
83+
)
84+
```
85+
86+
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
87+
88+
Now, we can launch training using:
89+
90+
```bash
91+
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
92+
export INSTANCE_DIR="dog"
93+
export OUTPUT_DIR="trained-flux"
94+
95+
accelerate launch train_dreambooth_flux.py \
96+
--pretrained_model_name_or_path=$MODEL_NAME \
97+
--instance_data_dir=$INSTANCE_DIR \
98+
--output_dir=$OUTPUT_DIR \
99+
--mixed_precision="fp16" \
100+
--instance_prompt="a photo of sks dog" \
101+
--resolution=1024 \
102+
--train_batch_size=1 \
103+
--gradient_accumulation_steps=4 \
104+
--learning_rate=1e-4 \
105+
--report_to="wandb" \
106+
--lr_scheduler="constant" \
107+
--lr_warmup_steps=0 \
108+
--max_train_steps=500 \
109+
--validation_prompt="A photo of sks dog in a bucket" \
110+
--validation_epochs=25 \
111+
--seed="0" \
112+
--push_to_hub
113+
```
114+
115+
To better track our training experiments, we're using the following flags in the command above:
116+
117+
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
118+
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
119+
120+
> [!NOTE]
121+
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
122+
123+
> [!TIP]
124+
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
125+
126+
## LoRA + DreamBooth
127+
128+
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
129+
130+
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
131+
132+
To perform DreamBooth with LoRA, run:
133+
134+
```bash
135+
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
136+
export INSTANCE_DIR="dog"
137+
export OUTPUT_DIR="trained-flux-lora"
138+
139+
accelerate launch train_dreambooth_lora_flux.py \
140+
--pretrained_model_name_or_path=$MODEL_NAME \
141+
--instance_data_dir=$INSTANCE_DIR \
142+
--output_dir=$OUTPUT_DIR \
143+
--mixed_precision="fp16" \
144+
--instance_prompt="a photo of sks dog" \
145+
--resolution=512 \
146+
--train_batch_size=1 \
147+
--gradient_accumulation_steps=4 \
148+
--learning_rate=1e-5 \
149+
--report_to="wandb" \
150+
--lr_scheduler="constant" \
151+
--lr_warmup_steps=0 \
152+
--max_train_steps=500 \
153+
--validation_prompt="A photo of sks dog in a bucket" \
154+
--validation_epochs=25 \
155+
--seed="0" \
156+
--push_to_hub
157+
```
158+
159+
### Text Encoder Training
160+
161+
Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
162+
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
163+
164+
> [!NOTE]
165+
> FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL).
166+
By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed.
167+
> At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
168+
169+
To perform DreamBooth LoRA with text-encoder training, run:
170+
```bash
171+
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
172+
export OUTPUT_DIR="trained-flux-dev-dreambooth-lora"
173+
174+
accelerate launch train_dreambooth_lora_flux.py \
175+
--pretrained_model_name_or_path=$MODEL_NAME \
176+
--instance_data_dir=$INSTANCE_DIR \
177+
--output_dir=$OUTPUT_DIR \
178+
--mixed_precision="fp16" \
179+
--train_text_encoder\
180+
--instance_prompt="a photo of sks dog" \
181+
--resolution=512 \
182+
--train_batch_size=1 \
183+
--gradient_accumulation_steps=4 \
184+
--learning_rate=1e-5 \
185+
--report_to="wandb" \
186+
--lr_scheduler="constant" \
187+
--lr_warmup_steps=0 \
188+
--max_train_steps=500 \
189+
--validation_prompt="A photo of sks dog in a bucket" \
190+
--seed="0" \
191+
--push_to_hub
192+
```
193+
194+
## Other notes
195+
Thanks to `bghira` for their help with reviewing & insight sharing ♥️

0 commit comments

Comments
 (0)