
Commit c3d899d

Merge pull request #101 from modelscope/Artiprocher-sd3-lora
Support SD3 LoRA
2 parents 8be4fad + 6e03ee2 commit c3d899d

File tree

13 files changed: +554 −127 lines

README.md

Lines changed: 6 additions & 22 deletions

````diff
@@ -80,15 +80,15 @@ https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5
 
 ### Image Synthesis
 
-Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/)
+Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
 
-|512*512|1024*1024|2048*2048|4096*4096|
-|-|-|-|-|
-|![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)|
+LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
 
-|1024*1024|2048*2048|
+|Stable Diffusion|Stable Diffusion XL|
 |-|-|
-|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|
+|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|
+|Stable Diffusion 3|Hunyuan-DiT|
+|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
 
 ### Toon Shading
 
@@ -104,22 +104,6 @@ Video stylization without video models. [`examples/diffsynth`](./examples/diffsy
 
 https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
 
-### Chinese Models
-
-Use Hunyuan-DiT to generate images with Chinese prompts. We also support LoRA fine-tuning of this model. [`examples/hunyuan_dit`](./examples/hunyuan_dit/)
-
-Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
-
-|1024x1024|2048x2048 (highres-fix)|
-|-|-|
-|![image_1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/2b6528cf-a229-46e9-b7dd-4a9475b07308)|![image_2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/11d264ec-966b-45c9-9804-74b60428b866)|
-
-Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
-
-|Without LoRA|With LoRA|
-|-|-|
-|![image_without_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1aa21de5-a992-4b66-b14f-caa44e08876e)|![image_with_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/83a0a41a-691f-4610-8e7b-d8e17c50a282)|
-
 ## Usage (in WebUI)
 
 ```
````

diffsynth/models/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -567,7 +567,7 @@ def load_stable_diffusion_3(self, state_dict, components=None, file_path=""):
         if component == "sd3_text_encoder_3":
             if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
                 continue
-        elif component == "sd3_text_encoder_1":
+        if component == "sd3_text_encoder_1":
             # Add additional token embeddings to text encoder
             token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
             for keyword in self.textual_inversion_dict:
```
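For context on the branch this hunk touches: the `sd3_text_encoder_1` path appends textual-inversion embeddings to CLIP's token-embedding matrix. A minimal sketch of that idea, with hypothetical shapes (not the repository's implementation):

```python
import torch

# Illustrative only: extending a token-embedding matrix with extra
# textual-inversion vectors. `base_embeddings` stands in for the
# `text_encoders.clip_l...token_embedding.weight` tensor from the state dict.
base_embeddings = torch.randn(49408, 768)   # (vocab_size, dim), CLIP-L-like
textual_inversion = torch.randn(4, 768)     # 4 learned pseudo-token vectors
extended = torch.cat([base_embeddings, textual_inversion], dim=0)
print(extended.shape)                       # torch.Size([49412, 768])
```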

diffsynth/models/sd3_dit.py

Lines changed: 16 additions & 2 deletions

```diff
@@ -199,16 +199,30 @@ def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb,
         )
         return hidden_states
 
-    def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64):
+    def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
         if tiled:
             return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
         conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
         prompt_emb = self.context_embedder(prompt_emb)
 
         height, width = hidden_states.shape[-2:]
         hidden_states = self.pos_embedder(hidden_states)
+
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+            return custom_forward
+
         for block in self.blocks:
-            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
+            if self.training and use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
+
         hidden_states = self.norm_out(hidden_states, conditioning)
         hidden_states = self.proj_out(hidden_states)
         hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
```
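The new `use_gradient_checkpointing` flag wraps each DiT block in `torch.utils.checkpoint.checkpoint`, which discards the block's activations after the forward pass and recomputes them during backward, trading compute for memory; `use_reentrant=False` selects PyTorch's newer non-reentrant implementation. A minimal, self-contained sketch of the same mechanism (illustrative, not the repository's code):

```python
import torch
import torch.utils.checkpoint

# A stand-in "block": during the checkpointed forward, its activations are
# not stored; they are recomputed when backward reaches this block.
block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())
x = torch.randn(2, 64, requires_grad=True)

y = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 64])
```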

diffsynth/prompts/sd3_prompter.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -69,7 +69,7 @@ def encode_prompt(
 
         # T5
         if text_encoder_3 is None:
-            prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
+            prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
         else:
             prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
             prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
```
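This one-line fix matters for batched prompts: the placeholder T5 embedding was always allocated with batch size 1, which no longer lines up once `prompt_emb_1` carries several prompts. A small illustration with hypothetical embedding shapes (only the batch dimension matters here):

```python
import torch

prompt_emb_1 = torch.randn(4, 77, 2048, dtype=torch.float16)  # e.g. 4 prompts from CLIP

# Before the fix: batch size hard-coded to 1.
before = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype)

# After the fix: the placeholder follows the CLIP batch size, so the
# per-prompt embeddings stay aligned downstream.
after = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype)
assert after.shape[0] == prompt_emb_1.shape[0]
```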

diffsynth/prompts/utils.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -124,6 +124,13 @@ def del_textual_inversion_tokens(self, prompt):
         return prompt
 
     def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
+        if isinstance(prompt, list):
+            prompt = [self.process_prompt(prompt_, positive=positive, require_pure_prompt=require_pure_prompt) for prompt_ in prompt]
+            if require_pure_prompt:
+                prompt, pure_prompt = [i[0] for i in prompt], [i[1] for i in prompt]
+                return prompt, pure_prompt
+            else:
+                return prompt
         prompt, pure_prompt = self.add_textual_inversion_tokens(prompt), self.del_textual_inversion_tokens(prompt)
         if positive and self.translator is not None:
             prompt = self.translator(prompt)
```
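`process_prompt` now also accepts a list of prompts: it recurses over the elements and, when `require_pure_prompt` is set, unzips the resulting `(prompt, pure_prompt)` tuples into two parallel lists. A standalone sketch of that control flow (a hypothetical simplification, not the class itself):

```python
def process_prompt(prompt, require_pure_prompt=False):
    if isinstance(prompt, list):
        results = [process_prompt(p, require_pure_prompt) for p in prompt]
        if require_pure_prompt:
            # Unzip [(prompt, pure_prompt), ...] into two parallel lists.
            return [r[0] for r in results], [r[1] for r in results]
        return results
    pure_prompt = prompt  # stand-in for textual-inversion token stripping
    return (prompt, pure_prompt) if require_pure_prompt else prompt

prompts, pure_prompts = process_prompt(["a cat", "a dog"], require_pure_prompt=True)
print(prompts, pure_prompts)  # ['a cat', 'a dog'] ['a cat', 'a dog']
```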

diffsynth/schedulers/flow_match.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -40,3 +40,8 @@ def add_noise(self, original_samples, noise, timestep):
         sigma = self.sigmas[timestep_id]
         sample = (1 - sigma) * original_samples + sigma * noise
         return sample
+
+
+    def training_target(self, sample, noise, timestep):
+        target = noise - sample
+        return target
```
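Given the scheduler's forward process `sample = (1 - sigma) * x0 + sigma * noise`, the derivative with respect to `sigma` is `noise - x0`, so a flow-matching model regresses that constant velocity; this is exactly what `training_target` returns. A sketch of where it slots into a training step (`model` and `scheduler` are hypothetical stand-ins, not DiffSynth's actual training loop):

```python
import torch
import torch.nn.functional as F

def training_step(model, scheduler, x0, timestep):
    noise = torch.randn_like(x0)
    noisy = scheduler.add_noise(x0, noise, timestep)          # (1 - sigma) * x0 + sigma * noise
    target = scheduler.training_target(x0, noise, timestep)   # noise - x0
    pred = model(noisy, timestep)
    return F.mse_loss(pred, target)
```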

examples/image_synthesis/README.md

Lines changed: 19 additions & 5 deletions

```diff
@@ -1,34 +1,48 @@
 # Image Synthesis
 
-Image synthesis is the base feature of DiffSynth Studio.
+Image synthesis is the base feature of DiffSynth Studio. We can generate images with very high resolution.
 
 ### Example: Stable Diffusion
 
-We can generate images with very high resolution. Please see [`sd_text_to_image.py`](./sd_text_to_image.py) for more details.
+Example script: [`sd_text_to_image.py`](./sd_text_to_image.py)
 
 |512*512|1024*1024|2048*2048|4096*4096|
 |-|-|-|-|
 |![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)|
 
 ### Example: Stable Diffusion XL
 
-Generate images with Stable Diffusion XL. Please see [`sdxl_text_to_image.py`](./sdxl_text_to_image.py) for more details.
+Example script: [`sdxl_text_to_image.py`](./sdxl_text_to_image.py)
 
 |1024*1024|2048*2048|
 |-|-|
 |![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|
 
 ### Example: Stable Diffusion 3
 
-Generate images with Stable Diffusion 3. High resolution is also supported in this model. See [`sd3_text_to_image.py`](./sd3_text_to_image.py).
+Example script: [`sd3_text_to_image.py`](./sd3_text_to_image.py)
+
+LoRA Training: [`../train/stable_diffusion_3/`](../train/stable_diffusion_3/)
 
 |1024*1024|2048*2048|
 |-|-|
 |![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|![image_2048](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/1386c802-e580-4101-939d-f1596802df9d)|
 
+### Example: Hunyuan-DiT
+
+Example script: [`hunyuan_dit_text_to_image.py`](./hunyuan_dit_text_to_image.py)
+
+LoRA Training: [`../train/hunyuan_dit/`](../train/hunyuan_dit/)
+
+|1024*1024|2048*2048|
+|-|-|
+|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|![image_2048](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/87919ea8-d428-4963-8257-da05f3901bbb)|
+
 ### Example: Stable Diffusion XL Turbo
 
-Generate images with Stable Diffusion XL Turbo. You can see [`sdxl_turbo.py`](./sdxl_turbo.py) for more details, but we highly recommend you to use it in the WebUI.
+Example script: [`sdxl_turbo.py`](./sdxl_turbo.py)
+
+We highly recommend you to use this model in the WebUI.
 
 |"black car"|"red car"|
 |-|-|
```
examples/image_synthesis/hunyuan_dit_text_to_image.py

Lines changed: 42 additions & 0 deletions

```diff
@@ -0,0 +1,42 @@
+from diffsynth import ModelManager, HunyuanDiTImagePipeline, download_models
+import torch
+
+
+# Download models (automatically)
+# `models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin`: [link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/clip_text_encoder/pytorch_model.bin)
+# `models/HunyuanDiT/t2i/mt5/pytorch_model.bin`: [link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/mt5/pytorch_model.bin)
+# `models/HunyuanDiT/t2i/model/pytorch_model_ema.pt`: [link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/model/pytorch_model_ema.pt)
+# `models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin`: [link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin)
+download_models(["HunyuanDiT"])
+
+# Load models
+model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+model_manager.load_models([
+    "models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
+    "models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
+    "models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
+    "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
+])
+pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
+
+prompt = "一幅充满诗意美感的全身肖像画,画中一位银发、蓝色眼睛、身穿蓝色连衣裙的少女漂浮在水下,周围是光彩的气泡,和煦的阳光透过水面折射进水下"
+negative_prompt = "错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,"
+
+# Enjoy!
+torch.manual_seed(0)
+image = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    num_inference_steps=50, height=1024, width=1024,
+)
+image.save("image_1024.png")
+
+# Highres fix
+image = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    input_image=image.resize((2048, 2048)),
+    num_inference_steps=50, height=2048, width=2048,
+    denoising_strength=0.4, tiled=True,
+)
+image.save("image_2048.png")
```

examples/image_synthesis/sd3_text_to_image.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -4,9 +4,9 @@
 
 # Download models (automatically)
 # `models/stable_diffusion_3/sd3_medium_incl_clips.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/sd3_medium_incl_clips.safetensors)
-download_models(["StableDiffusion3"])
+download_models(["StableDiffusion3_without_T5"])
 model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
-                             file_path_list=["models/stable_diffusion_3/sd3_medium_incl_clips_t5xxlfp16.safetensors"])
+                             file_path_list=["models/stable_diffusion_3/sd3_medium_incl_clips.safetensors"])
 pipe = SD3ImagePipeline.from_model_manager(model_manager)
 
 
```
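This change pairs with the `sd3_prompter.py` fix above: when no T5 encoder is loaded, `encode_prompt` substitutes a zero tensor for `prompt_emb_3`, so the smaller CLIP-only checkpoint is sufficient. If you do want the T5-XXL text encoder, the pre-PR configuration shown in the removed lines still applies; a sketch (the import line is an assumption about the script's header, which is not shown in this hunk):

```python
from diffsynth import ModelManager, SD3ImagePipeline, download_models
import torch

# Load the checkpoint that bundles the T5-XXL text encoder
# (larger download and higher VRAM use).
download_models(["StableDiffusion3"])
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
                             file_path_list=["models/stable_diffusion_3/sd3_medium_incl_clips_t5xxlfp16.safetensors"])
pipe = SD3ImagePipeline.from_model_manager(model_manager)
```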

examples/hunyuan_dit/README.md renamed to examples/train/hunyuan_dit/README.md

Lines changed: 2 additions & 94 deletions

````diff
@@ -28,99 +28,6 @@ from diffsynth import download_models
 download_models(["HunyuanDiT"])
 ```
 
-## Inference
-
-### Text-to-image with highres-fix
-
-The original resolution of Hunyuan DiT is 1024x1024. If you want to use larger resolutions, please use highres-fix.
-
-Hunyuan DiT is also supported in our UI.
-
-```python
-from diffsynth import ModelManager, HunyuanDiTImagePipeline
-import torch
-
-
-# Load models
-model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
-model_manager.load_models([
-    "models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
-    "models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
-    "models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
-    "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
-])
-pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
-
-# Enjoy!
-torch.manual_seed(0)
-image = pipe(
-    prompt="少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感",
-    negative_prompt="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,",
-    num_inference_steps=50, height=1024, width=1024,
-)
-image.save("image_1024.png")
-
-# Highres fix
-image = pipe(
-    prompt="少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感",
-    negative_prompt="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,",
-    input_image=image.resize((2048, 2048)),
-    num_inference_steps=50, height=2048, width=2048,
-    cfg_scale=3.0, denoising_strength=0.5, tiled=True,
-)
-image.save("image_2048.png")
-```
-
-Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
-
-|1024x1024|2048x2048 (highres-fix)|
-|-|-|
-|![image_1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/2b6528cf-a229-46e9-b7dd-4a9475b07308)|![image_2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/11d264ec-966b-45c9-9804-74b60428b866)|
-
-### In-context reference (experimental)
-
-This feature is similar to the "reference-only" mode in ControlNets. By extending the self-attention layer, the content in the reference image can be retained in the new image. Any number of reference images are supported, and the influence from each reference image can be controled by independent `reference_strengths` parameters.
-
-```python
-from diffsynth import ModelManager, HunyuanDiTImagePipeline
-import torch
-
-
-# Load models
-model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
-model_manager.load_models([
-    "models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
-    "models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
-    "models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
-    "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
-])
-pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
-
-# Generate an image as reference
-torch.manual_seed(0)
-reference_image = pipe(
-    prompt="梵高,星空,油画,明亮",
-    negative_prompt="",
-    num_inference_steps=50, height=1024, width=1024,
-)
-reference_image.save("image_reference.png")
-
-# Generate a new image with reference
-image = pipe(
-    prompt="层峦叠嶂的山脉,郁郁葱葱的森林,皎洁明亮的月光,夜色下的自然美景",
-    negative_prompt="",
-    reference_images=[reference_image], reference_strengths=[0.4],
-    num_inference_steps=50, height=1024, width=1024,
-)
-image.save("image_with_reference.png")
-```
-
-Prompt: 层峦叠嶂的山脉,郁郁葱葱的森林,皎洁明亮的月光,夜色下的自然美景
-
-|Reference image|Generated new image|
-|-|-|
-|![image_reference](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/99b0189d-6175-4842-b480-3c0d2f9f7e17)|![image_with_reference](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/8e41dddb-f302-4a2d-9e52-5487d1f47ae6)|
-
 ## Train
 
 ### Install training dependency
@@ -254,7 +161,8 @@ pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
 
 # Generate an image with lora
 pipe.dit = load_lora(
-    pipe.dit, lora_rank=4, lora_alpha=4.0,
+    pipe.dit,
+    lora_rank=4, lora_alpha=4.0, # The two parameters should be consistent with those in your training script.
     lora_path="path/to/your/lora/model/lightning_logs/version_x/checkpoints/epoch=x-step=xxx.ckpt"
 )
 torch.manual_seed(0)
````
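The added comment is worth heeding because `lora_alpha / lora_rank` scales the low-rank update when it is applied to the base weights; mismatched values silently change the effective LoRA strength. An illustrative merge under that convention (a hypothetical helper, not DiffSynth's `load_lora` internals):

```python
import torch

def merge_lora(weight, lora_up, lora_down, lora_rank=4, lora_alpha=4.0):
    # weight: (out, in); lora_up: (out, rank); lora_down: (rank, in).
    # The update is scaled by alpha/rank, so rank and alpha must match
    # the values used at training time.
    return weight + (lora_alpha / lora_rank) * (lora_up @ lora_down)

w = torch.randn(16, 16)
up, down = torch.randn(16, 4), torch.randn(4, 16)
merged = merge_lora(w, up, down)
print(merged.shape)  # torch.Size([16, 16])
```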
