
Commit 9e4b4a7

Merge branch 'main' into modify_is_sequential_off_load

2 parents c6ed536 + 4e57aef

File tree

12 files changed: +410 -306 lines


docs/source/en/api/models/autoencoderkl.md

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ The abstract from the paper is:
 ## Loading from the original format
 
 By default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
-from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows:
+from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
 
 ```py
 from diffusers import AutoencoderKL
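For reference, a minimal sketch of the single-file loading path this doc change points to, via the consolidated `FromOriginalModelMixin.from_single_file` entry point (the checkpoint URL is illustrative, not taken from the diff):

```py
from diffusers import AutoencoderKL

# Illustrative original-format checkpoint; any single-file VAE checkpoint works.
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
vae = AutoencoderKL.from_single_file(url)
```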

docs/source/en/api/models/controlnet.md

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ The abstract from the paper is:
 ## Loading from the original format
 
 By default the [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
-from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows:
+from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
 
 ```py
 from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
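Similarly, a hedged sketch of loading a ControlNet from a single original-format file and plugging it into a pipeline (the checkpoint URL and base model ID are illustrative):

```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

# Illustrative single-file ControlNet checkpoint in the original format.
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
controlnet = ControlNetModel.from_single_file(url)

# Attach the ControlNet to a Stable Diffusion pipeline.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet
)
```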

docs/source/en/api/pipelines/hunyuandit.md

Lines changed: 6 additions & 0 deletions

@@ -34,6 +34,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
 
 </Tip>
 
+<Tip>
+
+You can further improve generation quality by passing the generated image from [`HunyuanDiTPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
+
+</Tip>
+
 ## Optimization
 
 You can optimize the pipeline's runtime and memory consumption with torch.compile and feed-forward chunking. To learn about other optimization methods, check out the [Speed up inference](../../optimization/fp16) and [Reduce memory usage](../../optimization/memory) guides.
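To make the refiner hand-off from the new tip concrete, here's a minimal sketch, assuming the public HunyuanDiT and SDXL refiner checkpoints (model IDs are not part of the diff):

```py
import torch
from diffusers import HunyuanDiTPipeline, StableDiffusionXLImg2ImgPipeline

# Generate a base image with HunyuanDiT.
base = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")
prompt = "a photo of an astronaut riding a horse on the moon"
image = base(prompt).images[0]

# Pass the result through the SDXL refiner as an image-to-image step.
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")
refined = refiner(prompt=prompt, image=image).images[0]
```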

docs/source/en/api/pipelines/pixart_sigma.md

Lines changed: 6 additions & 0 deletions

@@ -37,6 +37,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
 
 </Tip>
 
+<Tip>
+
+You can further improve generation quality by passing the generated image from [`PixArtSigmaPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
+
+</Tip>
+
 ## Inference with under 8GB GPU VRAM
 
 Run the [`PixArtSigmaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let's walk through a full-fledged example.
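For orientation, a hedged sketch of the 8-bit text-encoder approach that section describes, assuming the public PixArt-Sigma checkpoint and bitsandbytes installed (IDs and flags here are illustrative, not from the diff):

```py
import torch
from transformers import T5EncoderModel
from diffusers import PixArtSigmaPipeline

# Load only the T5 text encoder in 8-bit precision (requires bitsandbytes).
text_encoder = T5EncoderModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    subfolder="text_encoder",
    load_in_8bit=True,
    device_map="auto",
)
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    text_encoder=text_encoder,
    transformer=None,
    device_map="balanced",
)

# Encode the prompt first; the denoising transformer can then be loaded
# separately once the text encoder has been freed.
with torch.no_grad():
    prompt_embeds, prompt_attention_mask, negative_embeds, negative_attention_mask = pipe.encode_prompt(
        "a cyberpunk cat"
    )
```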

docs/source/en/using-diffusers/sdxl.md

Lines changed: 6 additions & 0 deletions

@@ -285,6 +285,12 @@ refiner = DiffusionPipeline.from_pretrained(
 ).to("cuda")
 ```
 
+<Tip warning={true}>
+
+You can use the SDXL refiner with a different base model. For example, you can use the [Hunyuan-DiT](../../api/pipelines/hunyuandit) or [PixArt-Sigma](../../api/pipelines/pixart_sigma) pipelines to generate images with better prompt adherence. Once you have generated an image, you can pass it to the SDXL refiner model to enhance final generation quality.
+
+</Tip>
+
 Generate an image from the base model, and set the model output to **latent** space:
 
 ```py
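A minimal sketch of the base-to-latent-to-refiner flow this section walks through, assuming the standard SDXL base and refiner checkpoints:

```py
import torch
from diffusers import DiffusionPipeline

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
).to("cuda")

prompt = "a majestic lion jumping from a big stone at night"
# Keep the base output in latent space so the refiner can consume it directly.
latents = base(prompt=prompt, output_type="latent").images
image = refiner(prompt=prompt, image=latents).images[0]
```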
examples/dreambooth/test_dreambooth_lora_sd3.py

Lines changed: 165 additions & 0 deletions

@@ -0,0 +1,165 @@ (new file; all lines added)

# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRASD3(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
    script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"

    def test_dreambooth_lora_sd3(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_text_encoder_sd3(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            starts_with_expected_prefix = all(
                (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_expected_prefix)

    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                """.split()

            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
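The prefix checks these tests perform can be reproduced against any saved LoRA file; a minimal sketch, assuming an illustrative output path:

```py
import safetensors.torch

# Illustrative path; substitute the --output_dir used during training.
state_dict = safetensors.torch.load_file("output/pytorch_lora_weights.safetensors")

# Every parameter name should contain "lora", and each key should be prefixed
# with "transformer" (or "text_encoder" when --train_text_encoder was passed).
assert all("lora" in key for key in state_dict)
assert all(key.startswith(("transformer", "text_encoder")) for key in state_dict)
```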
