
Commit 25d1ee1

Merge branch 'main' into refactor-instructpix2pix_lora-toSupport-peft
2 parents: f037e79 + 055d955

30 files changed: +742 −285 lines

docs/source/en/api/models/autoencoder_kl_hunyuan_video.md

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
 ```python
 from diffusers import AutoencoderKLHunyuanVideo
 
-vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
+vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16)
 ```
 
 ## AutoencoderKLHunyuanVideo
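
For reference, a self-contained version of the updated snippet (a sketch: the `torch` import lives outside this hunk on the original page):

```python
import torch
from diffusers import AutoencoderKLHunyuanVideo

# Load only the VAE from the community mirror; fp16 is the recommended dtype for it.
vae = AutoencoderKLHunyuanVideo.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16
)
```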

docs/source/en/api/models/hunyuan_video_transformer_3d.md

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
 ```python
 from diffusers import HunyuanVideoTransformer3DModel
 
-transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
+transformer = HunyuanVideoTransformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16)
 ```
 
 ## HunyuanVideoTransformer3DModel
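
Combined with the VAE change above, the two snippets compose into a full pipeline load along these lines (a sketch; wiring the components into `HunyuanVideoPipeline` is our illustration, not part of this diff):

```python
import torch
from diffusers import (
    AutoencoderKLHunyuanVideo,
    HunyuanVideoPipeline,
    HunyuanVideoTransformer3DModel,
)

model_id = "hunyuanvideo-community/HunyuanVideo"

# bf16 transformer + fp16 VAE, matching the inference recommendations below.
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
vae = AutoencoderKLHunyuanVideo.from_pretrained(
    model_id, subfolder="vae", torch_dtype=torch.float16
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id, transformer=transformer, vae=vae, torch_dtype=torch.float16
)
```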

docs/source/en/api/models/sana_transformer2d.md

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
 ```python
 from diffusers import SanaTransformer2DModel
 
-transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
+transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
 ```
 
 ## SanaTransformer2DModel
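
A self-contained version of the updated Sana snippet (a sketch; as above, `torch` is imported elsewhere on the original page):

```python
import torch
from diffusers import SanaTransformer2DModel

# The BF16 checkpoint pairs with torch.bfloat16, per the dtype table in sana.md below.
transformer = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
```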

docs/source/en/api/pipelines/hunyuan_video.md

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ Recommendations for inference:
 - Transformer should be in `torch.bfloat16`.
 - VAE should be in `torch.float16`.
 - `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
-- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
+- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
 - For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
 
 ## HunyuanVideoPipeline
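
The `shift` recommendation can be applied when building the scheduler; a minimal sketch (the value `3.0` is an illustrative low-resolution choice, not taken from the diff):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

# Scheduler kwargs passed to from_pretrained override the stored config,
# so shift can be set here; HunyuanVideo's default is 7.0.
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="scheduler", shift=3.0
)
```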

docs/source/en/api/pipelines/sana.md

Lines changed: 1 addition & 1 deletion
@@ -32,9 +32,9 @@ Available models:
 
 | Model | Recommended dtype |
 |:-----:|:-----------------:|
+| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
 | [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
 | [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
-| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
 | [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
 | [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
 | [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
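
A loading example matching the newly promoted first row (a sketch; the prompt is ours):

```python
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe(prompt="a cyberpunk cat with a neon sign").images[0]
```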

docs/source/en/quantization/torchao.md

Lines changed: 4 additions & 2 deletions
@@ -27,7 +27,7 @@ The example below only quantizes the weights to int8.
 ```python
 from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
 
-model_id = "black-forest-labs/Flux.1-Dev"
+model_id = "black-forest-labs/FLUX.1-dev"
 dtype = torch.bfloat16
 
 quantization_config = TorchAoConfig("int8wo")
@@ -45,7 +45,9 @@ pipe = FluxPipeline.from_pretrained(
 pipe.to("cuda")
 
 prompt = "A cat holding a sign that says hello world"
-image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
+image = pipe(
+    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
+).images[0]
 image.save("output.png")
 ```
 
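Reassembled from both hunks, the updated example reads roughly as follows (the lines elided between the hunks, which quantize and load the transformer, are reconstructed from context and may differ from the actual file):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
# Reconstructed middle section: quantize only the transformer weights to int8.
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=quantization_config, torch_dtype=dtype
)
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```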

examples/dreambooth/test_dreambooth_lora_sana.py

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRASANA(ExamplesTestsAccelerate):
+    instance_data_dir = "docs/source/en/imgs"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
+    script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
+    transformer_layer_type = "transformer_blocks.0.attn1.to_k"
+
+    def test_dreambooth_lora_sana(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --resolution 32
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --max_sequence_length 16
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names.
+            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_transformer)
+
+    def test_dreambooth_lora_latent_caching(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --resolution 32
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --cache_latents
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --max_sequence_length 16
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names.
+            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_transformer)
+
+    def test_dreambooth_lora_layers(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --resolution 32
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --cache_latents
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lora_layers {self.transformer_layer_type}
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --max_sequence_length 16
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names. In this test, only params of
+            # `self.transformer_layer_type` should be in the state dict.
+            starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
+            self.assertTrue(starts_with_transformer)
+
+    def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=6
+                --checkpoints_total_limit=2
+                --checkpointing_steps=2
+                --max_sequence_length 16
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --checkpointing_steps=2
+                --max_sequence_length 16
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+            resume_run_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=8
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-4
+                --checkpoints_total_limit=2
+                --max_sequence_length 16
+                """.split()
+
+            resume_run_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + resume_run_args)
+
+            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
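The recurring assertions in these tests reduce to a small check over the saved LoRA file; in isolation (a sketch with a hypothetical output path):

```python
import safetensors.torch

# Hypothetical path; the tests write this file into a temporary output_dir.
state_dict = safetensors.torch.load_file("output/pytorch_lora_weights.safetensors")

# Every saved tensor should be a LoRA parameter...
assert all("lora" in key for key in state_dict)
# ...and, with the text encoder frozen, should belong to the transformer.
assert all(key.startswith("transformer") for key in state_dict)
```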
examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 12 additions & 11 deletions
@@ -943,7 +943,7 @@ def main(args):
 
     # Load scheduler and models
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="scheduler"
+        args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
     text_encoder = Gemma2Model.from_pretrained(
@@ -964,15 +964,6 @@
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
 
-    # Initialize a text encoding pipeline and keep it to CPU for now.
-    text_encoding_pipeline = SanaPipeline.from_pretrained(
-        args.pretrained_model_name_or_path,
-        vae=None,
-        transformer=None,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-    )
-
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
@@ -993,6 +984,15 @@
     # because Gemma2 is particularly suited for bfloat16.
     text_encoder.to(dtype=torch.bfloat16)
 
+    # Initialize a text encoding pipeline and keep it to CPU for now.
+    text_encoding_pipeline = SanaPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        vae=None,
+        transformer=None,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+    )
+
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
 
@@ -1182,6 +1182,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
             )
         if args.offload:
            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        prompt_embeds = prompt_embeds.to(transformer.dtype)
         return prompt_embeds, prompt_attention_mask
 
 
@@ -1216,7 +1217,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     vae_config_scaling_factor = vae.config.scaling_factor
     if args.cache_latents:
         latents_cache = []
-        vae = vae.to("cuda")
+        vae = vae.to(accelerator.device)
        for batch in tqdm(train_dataloader, desc="Caching latents"):
             with torch.no_grad():
                 batch["pixel_values"] = batch["pixel_values"].to(

scripts/convert_sana_to_diffusers.py

Lines changed: 6 additions & 0 deletions
@@ -88,13 +88,18 @@ def main(args):
     # y norm
     converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
 
+    # scheduler
     flow_shift = 3.0
+
+    # model config
     if args.model_type == "SanaMS_1600M_P1_D20":
         layer_num = 20
     elif args.model_type == "SanaMS_600M_P1_D28":
         layer_num = 28
     else:
         raise ValueError(f"{args.model_type} is not supported.")
+    # Positional embedding interpolation scale.
+    interpolation_scale = {512: None, 1024: None, 2048: 1.0}
 
     for depth in range(layer_num):
         # Transformer blocks.
@@ -176,6 +181,7 @@ def main(args):
         patch_size=1,
         norm_elementwise_affine=False,
         norm_eps=1e-6,
+        interpolation_scale=interpolation_scale[args.image_size],
     )
 
     if is_accelerate_available():
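
The new `interpolation_scale` mapping is a plain resolution-keyed lookup feeding the model config; its behavior in isolation (a sketch using only names from the diff):

```python
# Positional-embedding interpolation scale, keyed by training resolution.
interpolation_scale = {512: None, 1024: None, 2048: 1.0}

# Only the 2K checkpoint pins a scale; 512/1024 pass None, presumably letting
# SanaTransformer2DModel fall back to its default.
for image_size in (512, 1024, 2048):
    print(image_size, interpolation_scale[image_size])
```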

src/diffusers/loaders/single_file_model.py

Lines changed: 7 additions & 1 deletion
@@ -28,6 +28,7 @@
     convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
+    convert_hunyuan_video_transformer_to_diffusers,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
     convert_ltx_transformer_checkpoint_to_diffusers,
@@ -101,6 +102,10 @@
         "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "HunyuanVideoTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
 
 
@@ -220,6 +225,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         local_files_only = kwargs.pop("local_files_only", None)
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
+        config_revision = kwargs.pop("config_revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
@@ -297,7 +303,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             subfolder=subfolder,
             local_files_only=local_files_only,
             token=token,
-            revision=revision,
+            revision=config_revision,
         )
         expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
 
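A usage sketch for the new `config_revision` kwarg (the single-file checkpoint URL is illustrative, not taken from this diff):

```python
import torch
from diffusers import HunyuanVideoTransformer3DModel

# `config` points at a diffusers-format repo; `config_revision` pins which
# revision of that config repo is fetched, independently of `revision`.
transformer = HunyuanVideoTransformer3DModel.from_single_file(
    "https://huggingface.co/tencent/HunyuanVideo/blob/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
    config="hunyuanvideo-community/HunyuanVideo",
    config_revision="main",
    torch_dtype=torch.bfloat16,
)
```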