
Commit 413ca29

[Flux Dreambooth LoRA] - te bug fixes & updates (#9139)
* add requirements + fix link to bghira's guide
* text encoder training fixes
* text encoder training fixes
* text encoder training fixes
* text encoder training fixes
* style
* add tests
* fix encode_prompt call
* style
* unpack_latents test
* fix lora saving
* remove default val for max_sequence_length in encode_prompt
* remove default val for max_sequence_length in encode_prompt
* style
* testing
* style
* testing
* testing
* style
* fix sizing issue
* style
* revert scaling
* style
* style
* scaling test
* style
* scaling test
* remove model pred operation left from pre-conditioning
* remove model pred operation left from pre-conditioning
* fix trainable params
* remove te2 from casting
* transformer to accelerator
* remove prints
* empty commit
1 parent 10dc06c commit 413ca29
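Several bullets above touch the training script's `encode_prompt` helper, whose exact signature is not shown in this diff. As a stand-in sketch, the public `FluxPipeline.encode_prompt` takes the same `max_sequence_length` argument, now passed explicitly rather than via a default:

from diffusers import FluxPipeline

# Uses the tiny test checkpoint that the tests added in this commit also use.
pipe = FluxPipeline.from_pretrained("hf-internal-testing/tiny-flux-pipe")

# Pass max_sequence_length explicitly, mirroring the "remove default val for
# max_sequence_length in encode_prompt" bullets above.
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt="a photo of sks dog",
    prompt_2=None,
    max_sequence_length=512,
)
print(prompt_embeds.shape, pooled_prompt_embeds.shape)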

File tree

6 files changed: +488 −92 lines changed

examples/dreambooth/README_flux.md

Lines changed: 4 additions & 4 deletions
@@ -8,7 +8,7 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced
 >
 > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
 > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
-> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](documentation/quickstart/FLUX.md)
+> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
 
 
 > [!NOTE]
@@ -96,7 +96,7 @@ accelerate launch train_dreambooth_flux.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
-  --mixed_precision="fp16" \
+  --mixed_precision="bf16" \
   --instance_prompt="a photo of sks dog" \
   --resolution=1024 \
   --train_batch_size=1 \
@@ -140,7 +140,7 @@ accelerate launch train_dreambooth_lora_flux.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
-  --mixed_precision="fp16" \
+  --mixed_precision="bf16" \
   --instance_prompt="a photo of sks dog" \
   --resolution=512 \
   --train_batch_size=1 \
@@ -175,7 +175,7 @@ accelerate launch train_dreambooth_lora_flux.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
-  --mixed_precision="fp16" \
+  --mixed_precision="bf16" \
   --train_text_encoder\
   --instance_prompt="a photo of sks dog" \
   --resolution=512 \
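All three example commands switch `--mixed_precision` from `"fp16"` to `"bf16"`, presumably because Flux's reference weights ship in bfloat16 and bf16's wider exponent range avoids the overflow that fp16 training can hit. A minimal sketch of loading the transformer directly in bf16, using the tiny test checkpoint from this commit's tests:

import torch
from diffusers import FluxTransformer2DModel

# The tiny test checkpoint stands in for the full Flux model here.
transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", torch_dtype=torch.bfloat16
)
print(next(transformer.parameters()).dtype)  # torch.bfloat16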
examples/dreambooth/requirements_flux.txt

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
1+
accelerate>=0.31.0
2+
torchvision
3+
transformers>=4.41.2
4+
ftfy
5+
tensorboard
6+
Jinja2
7+
peft>=0.11.1
8+
sentencepiece
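A small pre-flight check (hypothetical helper, not part of this commit) can assert the pinned minimums before launching training:

from importlib.metadata import version

from packaging.version import Version

# Hypothetical check against the minimum versions pinned above.
for pkg, minimum in {"accelerate": "0.31.0", "transformers": "4.41.2", "peft": "0.11.1"}.items():
    assert Version(version(pkg)) >= Version(minimum), f"{pkg} < {minimum}"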
examples/dreambooth/test_dreambooth_flux.py

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
import shutil
19+
import sys
20+
import tempfile
21+
22+
from diffusers import DiffusionPipeline, FluxTransformer2DModel
23+
24+
25+
sys.path.append("..")
26+
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
27+
28+
29+
logging.basicConfig(level=logging.DEBUG)
30+
31+
logger = logging.getLogger()
32+
stream_handler = logging.StreamHandler(sys.stdout)
33+
logger.addHandler(stream_handler)
34+
35+
36+
class DreamBoothFlux(ExamplesTestsAccelerate):
37+
instance_data_dir = "docs/source/en/imgs"
38+
instance_prompt = "photo"
39+
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
40+
script_path = "examples/dreambooth/train_dreambooth_flux.py"
41+
42+
def test_dreambooth(self):
43+
with tempfile.TemporaryDirectory() as tmpdir:
44+
test_args = f"""
45+
{self.script_path}
46+
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
47+
--instance_data_dir {self.instance_data_dir}
48+
--instance_prompt {self.instance_prompt}
49+
--resolution 64
50+
--train_batch_size 1
51+
--gradient_accumulation_steps 1
52+
--max_train_steps 2
53+
--learning_rate 5.0e-04
54+
--scale_lr
55+
--lr_scheduler constant
56+
--lr_warmup_steps 0
57+
--output_dir {tmpdir}
58+
""".split()
59+
60+
run_command(self._launch_args + test_args)
61+
# save_pretrained smoke test
62+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
63+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
64+
65+
def test_dreambooth_checkpointing(self):
66+
with tempfile.TemporaryDirectory() as tmpdir:
67+
# Run training script with checkpointing
68+
# max_train_steps == 4, checkpointing_steps == 2
69+
# Should create checkpoints at steps 2, 4
70+
71+
initial_run_args = f"""
72+
{self.script_path}
73+
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
74+
--instance_data_dir {self.instance_data_dir}
75+
--instance_prompt {self.instance_prompt}
76+
--resolution 64
77+
--train_batch_size 1
78+
--gradient_accumulation_steps 1
79+
--max_train_steps 4
80+
--learning_rate 5.0e-04
81+
--scale_lr
82+
--lr_scheduler constant
83+
--lr_warmup_steps 0
84+
--output_dir {tmpdir}
85+
--checkpointing_steps=2
86+
--seed=0
87+
""".split()
88+
89+
run_command(self._launch_args + initial_run_args)
90+
91+
# check can run the original fully trained output pipeline
92+
pipe = DiffusionPipeline.from_pretrained(tmpdir)
93+
pipe(self.instance_prompt, num_inference_steps=1)
94+
95+
# check checkpoint directories exist
96+
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
97+
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
98+
99+
# check can run an intermediate checkpoint
100+
transformer = FluxTransformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
101+
pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
102+
pipe(self.instance_prompt, num_inference_steps=1)
103+
104+
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
105+
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
106+
107+
# Run training script for 7 total steps resuming from checkpoint 4
108+
109+
resume_run_args = f"""
110+
{self.script_path}
111+
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
112+
--instance_data_dir {self.instance_data_dir}
113+
--instance_prompt {self.instance_prompt}
114+
--resolution 64
115+
--train_batch_size 1
116+
--gradient_accumulation_steps 1
117+
--max_train_steps 6
118+
--learning_rate 5.0e-04
119+
--scale_lr
120+
--lr_scheduler constant
121+
--lr_warmup_steps 0
122+
--output_dir {tmpdir}
123+
--checkpointing_steps=2
124+
--resume_from_checkpoint=checkpoint-4
125+
--seed=0
126+
""".split()
127+
128+
run_command(self._launch_args + resume_run_args)
129+
130+
# check can run new fully trained pipeline
131+
pipe = DiffusionPipeline.from_pretrained(tmpdir)
132+
pipe(self.instance_prompt, num_inference_steps=1)
133+
134+
# check old checkpoints do not exist
135+
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
136+
137+
# check new checkpoints exist
138+
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
139+
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
140+
141+
def test_dreambooth_checkpointing_checkpoints_total_limit(self):
142+
with tempfile.TemporaryDirectory() as tmpdir:
143+
test_args = f"""
144+
{self.script_path}
145+
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
146+
--instance_data_dir={self.instance_data_dir}
147+
--output_dir={tmpdir}
148+
--instance_prompt={self.instance_prompt}
149+
--resolution=64
150+
--train_batch_size=1
151+
--gradient_accumulation_steps=1
152+
--max_train_steps=6
153+
--checkpoints_total_limit=2
154+
--checkpointing_steps=2
155+
""".split()
156+
157+
run_command(self._launch_args + test_args)
158+
159+
self.assertEqual(
160+
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
161+
{"checkpoint-4", "checkpoint-6"},
162+
)
163+
164+
def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
165+
with tempfile.TemporaryDirectory() as tmpdir:
166+
test_args = f"""
167+
{self.script_path}
168+
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
169+
--instance_data_dir={self.instance_data_dir}
170+
--output_dir={tmpdir}
171+
--instance_prompt={self.instance_prompt}
172+
--resolution=64
173+
--train_batch_size=1
174+
--gradient_accumulation_steps=1
175+
--max_train_steps=4
176+
--checkpointing_steps=2
177+
""".split()
178+
179+
run_command(self._launch_args + test_args)
180+
181+
self.assertEqual(
182+
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
183+
{"checkpoint-2", "checkpoint-4"},
184+
)
185+
186+
resume_run_args = f"""
187+
{self.script_path}
188+
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
189+
--instance_data_dir={self.instance_data_dir}
190+
--output_dir={tmpdir}
191+
--instance_prompt={self.instance_prompt}
192+
--resolution=64
193+
--train_batch_size=1
194+
--gradient_accumulation_steps=1
195+
--max_train_steps=8
196+
--checkpointing_steps=2
197+
--resume_from_checkpoint=checkpoint-4
198+
--checkpoints_total_limit=2
199+
""".split()
200+
201+
run_command(self._launch_args + resume_run_args)
202+
203+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
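The class builds on `ExamplesTestsAccelerate`, which presumably supplies `_launch_args` (an `accelerate launch` prefix), as in the repository's other example-script tests. Assuming the module path suggested in the file header above, a single test can be run via pytest:

# Sketch: run one of the new tests with pytest (module path assumed, adjust if
# the test file lives elsewhere).
import pytest

raise SystemExit(pytest.main(["-q", "examples/dreambooth/test_dreambooth_flux.py::DreamBoothFlux::test_dreambooth"]))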
