
Commit e912ff8

Merge branch 'main' into lora_modules
2 parents 8c95792 + 5704376

27 files changed: +4549 -114 lines

docs/source/en/api/pipelines/cogvideox.md

Lines changed: 10 additions & 0 deletions
@@ -36,6 +36,10 @@ There are two models available that can be used with the text-to-video and video
There is one model available that can be used with the image-to-video CogVideoX pipeline:
- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.

+There are two models that support pose-controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team):
+- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`.
+- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`.
+
## Inference

Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
@@ -118,6 +122,12 @@ It is also worth noting that torchao quantization is fully compatible with [torc
  - all
  - __call__

+## CogVideoXFunControlPipeline
+
+[[autodoc]] CogVideoXFunControlPipeline
+  - all
+  - __call__
+
## CogVideoXPipelineOutput

[[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput
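For orientation, a minimal sketch of how the pipeline documented above might be invoked, assuming it follows the standard diffusers `from_pretrained`/`__call__` pattern and accepts a `control_video` argument of decoded frames (the prompt and file paths below are placeholders):

```python
import torch
from diffusers import CogVideoXFunControlPipeline
from diffusers.utils import export_to_video, load_video

# Load the pose-conditioned CogVideoX-Fun checkpoint in the recommended bf16 dtype.
pipe = CogVideoXFunControlPipeline.from_pretrained(
    "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16
).to("cuda")

# A pose video conditions the generation; "pose.mp4" is a placeholder input.
control_video = load_video("pose.mp4")
frames = pipe(
    prompt="A dancer performing on an outdoor stage",  # illustrative prompt
    control_video=control_video,  # assumed parameter name for the conditioning frames
).frames[0]
export_to_video(frames, "output.mp4", fps=8)
```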

examples/advanced_diffusion_training/README_flux.md

Lines changed: 353 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 8 additions & 0 deletions (new file)
@@ -0,0 +1,8 @@
accelerate>=0.31.0
torchvision
transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft>=0.11.1
sentencepiece
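As a convenience, the version floors pinned above can be sanity-checked from Python before launching training; a small standard-library sketch (pip normally enforces these at install time):

```python
from importlib.metadata import version

# Print installed versions of the pinned packages for a quick eyeball check.
for pkg, floor in [("accelerate", "0.31.0"), ("transformers", "4.41.2"), ("peft", "0.11.1")]:
    print(f"{pkg}: installed {version(pkg)}, requires >= {floor}")
```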
Lines changed: 283 additions & 0 deletions (new file)
@@ -0,0 +1,283 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
    script_path = "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py"

    def test_dreambooth_lora_flux(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_text_encoder_flux(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            starts_with_expected_prefix = all(
                (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_expected_prefix)

    def test_dreambooth_lora_pivotal_tuning_flux_clip(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder_ti
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
            # make sure embeddings were also saved
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # make sure the state_dict has the correct naming in the parameters.
            textual_inversion_state_dict = safetensors.torch.load_file(
                os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
            )
            is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys())
            self.assertTrue(is_clip)

            # when performing pivotal tuning, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder_ti
                --enable_t5_ti
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
            # make sure embeddings were also saved
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # make sure the state_dict has the correct naming in the parameters.
            textual_inversion_state_dict = safetensors.torch.load_file(
                os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
            )
            is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys())
            self.assertTrue(is_te)

            # when performing pivotal tuning, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_latent_caching(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                """.split()

            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
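The assertions above boil down to inspecting the two artifacts the training script writes. A standalone sketch of the same checks outside the test harness (the output directory name below is a placeholder):

```python
import os

from safetensors.torch import load_file

output_dir = "my_flux_lora_run"  # placeholder: the --output_dir passed to the script

# LoRA weights: every key should contain "lora"; without --train_text_encoder,
# every key should also live under the "transformer" prefix.
lora_state_dict = load_file(os.path.join(output_dir, "pytorch_lora_weights.safetensors"))
print(all("lora" in k for k in lora_state_dict))
print({k.split(".")[0] for k in lora_state_dict})  # e.g. {"transformer"}, plus "text_encoder" if trained

# With --train_text_encoder_ti, pivotal-tuning embeddings land in "<run>_emb.safetensors";
# expected keys are "clip_l", plus "t5" when --enable_t5_ti is set.
emb_path = os.path.join(output_dir, f"{os.path.basename(output_dir)}_emb.safetensors")
if os.path.isfile(emb_path):
    print(load_file(emb_path).keys())
```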
