Commit 979a881

support SD3 LoRA

1 parent 8113f95

File tree: 13 files changed (+1030, -32 lines)

README.md

Lines changed: 6 additions & 22 deletions

````diff
@@ -80,15 +80,15 @@ https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5
 
 ### Image Synthesis
 
-Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/)
+Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
 
-|512*512|1024*1024|2048*2048|4096*4096|
-|-|-|-|-|
-|![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)|
+LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
 
-|1024*1024|2048*2048|
+|Stable Diffusion|Stable Diffusion XL|
 |-|-|
-|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|
+|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|
+|Stable Diffusion 3|Hunyuan-DiT|
+|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
 
 ### Toon Shading
 
@@ -104,22 +104,6 @@ Video stylization without video models. [`examples/diffsynth`](./examples/diffsy
 
 https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
 
-### Chinese Models
-
-Use Hunyuan-DiT to generate images with Chinese prompts. We also support LoRA fine-tuning of this model. [`examples/hunyuan_dit`](./examples/hunyuan_dit/)
-
-Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
-
-|1024x1024|2048x2048 (highres-fix)|
-|-|-|
-|![image_1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/2b6528cf-a229-46e9-b7dd-4a9475b07308)|![image_2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/11d264ec-966b-45c9-9804-74b60428b866)|
-
-Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
-
-|Without LoRA|With LoRA|
-|-|-|
-|![image_without_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1aa21de5-a992-4b66-b14f-caa44e08876e)|![image_with_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/83a0a41a-691f-4610-8e7b-d8e17c50a282)|
-
 ## Usage (in WebUI)
 
 ```
````

(The two removed Chinese prompts translate roughly as "A girl holding flowers, sitting on a park bench, the glow of the setting sun on her face, the whole scene full of poetic beauty" and "A little puppy bounding about, surrounded by colorful flowers, with mountains in the distance".)

diffsynth/models/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -567,7 +567,7 @@ def load_stable_diffusion_3(self, state_dict, components=None, file_path=""):
         if component == "sd3_text_encoder_3":
             if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
                 continue
-        elif component == "sd3_text_encoder_1":
+        if component == "sd3_text_encoder_1":
             # Add additional token embeddings to text encoder
             token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
             for keyword in self.textual_inversion_dict:
```

diffsynth/models/sd3_dit.py

Lines changed: 16 additions & 2 deletions

```diff
@@ -199,16 +199,30 @@ def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb,
         )
         return hidden_states
 
-    def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64):
+    def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
         if tiled:
             return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
         conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
         prompt_emb = self.context_embedder(prompt_emb)
 
         height, width = hidden_states.shape[-2:]
         hidden_states = self.pos_embedder(hidden_states)
+
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+            return custom_forward
+
         for block in self.blocks:
-            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
+            if self.training and use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
+
         hidden_states = self.norm_out(hidden_states, conditioning)
         hidden_states = self.proj_out(hidden_states)
         hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
```
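
The new `use_gradient_checkpointing` path trades compute for memory during training: activations inside each DiT block are discarded on the forward pass and recomputed during backward. The `create_custom_forward` wrapper adapts a module call to `torch.utils.checkpoint.checkpoint`, which invokes the wrapped callable with positional tensors only. A minimal standalone sketch of the same pattern, using a toy `torch.nn.Linear` in place of an SD3 block (module and shapes are illustrative assumptions):

```python
import torch

# Toy stand-in for one transformer block; shapes are illustrative.
block = torch.nn.Linear(64, 64)
x = torch.randn(2, 64, requires_grad=True)

def create_custom_forward(module):
    # checkpoint() calls the wrapped function with positional tensors only,
    # so we adapt the module call to that convention.
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

# Forward pass: activations inside `block` are not stored.
y = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), x, use_reentrant=False,
)
# Backward pass: `block` runs forward again to rebuild its activations,
# then gradients flow as usual.
y.sum().backward()
```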

diffsynth/prompts/sd3_prompter.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -69,7 +69,7 @@ def encode_prompt(
 
         # T5
         if text_encoder_3 is None:
-            prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
+            prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
         else:
             prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
         prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
```
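
The fix replaces a hard-coded batch size of 1 with `prompt_emb_1.shape[0]`, so that when no T5 encoder is loaded, the zero placeholder embedding matches the CLIP embedding's batch dimension, which batched prompts require. A small sketch of the shape logic (the 77 sequence length is an illustrative assumption; 256 and 4096 come from the file above):

```python
import torch

# Two prompts in a batch: the CLIP-derived embedding has batch size 2.
prompt_emb_1 = torch.randn(2, 77, 4096, dtype=torch.float16)

# Old code built torch.zeros((1, 256, 4096)): a batch mismatch with prompt_emb_1.
# The fixed code mirrors the batch dimension instead:
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype)

# The two embeddings concatenate along the sequence axis for the DiT input.
joint_emb = torch.cat([prompt_emb_1, prompt_emb_3], dim=1)
assert joint_emb.shape == (2, 77 + 256, 4096)
```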

diffsynth/prompts/utils.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -124,6 +124,13 @@ def del_textual_inversion_tokens(self, prompt):
         return prompt
 
     def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
+        if isinstance(prompt, list):
+            prompt = [self.process_prompt(prompt_, positive=positive, require_pure_prompt=require_pure_prompt) for prompt_ in prompt]
+            if require_pure_prompt:
+                prompt, pure_prompt = [i[0] for i in prompt], [i[1] for i in prompt]
+                return prompt, pure_prompt
+            else:
+                return prompt
         prompt, pure_prompt = self.add_textual_inversion_tokens(prompt), self.del_textual_inversion_tokens(prompt)
         if positive and self.translator is not None:
             prompt = self.translator(prompt)
```
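
With this change, `process_prompt` also accepts a list of prompts: it recurses on each element and, when `require_pure_prompt` is set, unzips the per-element `(prompt, pure_prompt)` pairs back into two parallel lists. The unzip idiom in isolation, with hypothetical values:

```python
# Each recursive call returns a (prompt, pure_prompt) pair; the list branch
# turns [(p1, q1), (p2, q2)] into ([p1, p2], [q1, q2]).
pairs = [
    ("a painting of <my-style> cat", "a painting of cat"),
    ("a painting of <my-style> dog", "a painting of dog"),
]
prompts, pure_prompts = [i[0] for i in pairs], [i[1] for i in pairs]
assert prompts == ["a painting of <my-style> cat", "a painting of <my-style> dog"]
assert pure_prompts == ["a painting of cat", "a painting of dog"]
```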

diffsynth/schedulers/flow_match.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -40,3 +40,8 @@ def add_noise(self, original_samples, noise, timestep):
         sigma = self.sigmas[timestep_id]
         sample = (1 - sigma) * original_samples + sigma * noise
         return sample
+
+
+    def training_target(self, sample, noise, timestep):
+        target = noise - sample
+        return target
```
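
`training_target` supplies the regression target for rectified-flow training. Since `add_noise` interpolates `x_t = (1 - sigma) * x_0 + sigma * noise`, the velocity `dx_t/dsigma = noise - x_0` is constant along the path, and the model is trained to predict it. A standalone sketch of the resulting objective (all tensors are dummies; the model prediction is a placeholder):

```python
import torch
import torch.nn.functional as F

x0 = torch.randn(1, 16, 64, 64)          # clean latents; shape is illustrative
noise = torch.randn_like(x0)
sigma = torch.rand(())                   # stands in for self.sigmas[timestep_id]

x_t = (1 - sigma) * x0 + sigma * noise   # what add_noise computes
target = noise - x0                      # what training_target returns

pred = torch.randn_like(x0)              # placeholder for the DiT's velocity prediction
loss = F.mse_loss(pred, target)          # typical flow-matching training loss
```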

examples/image_synthesis/README.md

Lines changed: 19 additions & 5 deletions

```diff
@@ -1,34 +1,48 @@
 # Image Synthesis
 
-Image synthesis is the base feature of DiffSynth Studio.
+Image synthesis is the base feature of DiffSynth Studio. We can generate images with very high resolution.
 
 ### Example: Stable Diffusion
 
-We can generate images with very high resolution. Please see [`sd_text_to_image.py`](./sd_text_to_image.py) for more details.
+Example script: [`sd_text_to_image.py`](./sd_text_to_image.py)
 
 |512*512|1024*1024|2048*2048|4096*4096|
 |-|-|-|-|
 |![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)|
 
 ### Example: Stable Diffusion XL
 
-Generate images with Stable Diffusion XL. Please see [`sdxl_text_to_image.py`](./sdxl_text_to_image.py) for more details.
+Example script: [`sdxl_text_to_image.py`](./sdxl_text_to_image.py)
 
 |1024*1024|2048*2048|
 |-|-|
 |![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|
 
 ### Example: Stable Diffusion 3
 
-Generate images with Stable Diffusion 3. High resolution is also supported in this model. See [`sd3_text_to_image.py`](./sd3_text_to_image.py).
+Example script: [`sd3_text_to_image.py`](./sd3_text_to_image.py)
+
+LoRA Training: [`../train/stable_diffusion_3/`](../train/stable_diffusion_3/)
 
 |1024*1024|2048*2048|
 |-|-|
 |![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|![image_2048](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/1386c802-e580-4101-939d-f1596802df9d)|
 
+### Example: Hunyuan-DiT
+
+Example script: [`hunyuan_dit_text_to_image.py`](./hunyuan_dit_text_to_image.py)
+
+LoRA Training: [`../train/hunyuan_dit/`](../train/hunyuan_dit/)
+
+|1024*1024|2048*2048|
+|-|-|
+|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|![image_2048](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/87919ea8-d428-4963-8257-da05f3901bbb)|
+
 ### Example: Stable Diffusion XL Turbo
 
-Generate images with Stable Diffusion XL Turbo. You can see [`sdxl_turbo.py`](./sdxl_turbo.py) for more details, but we highly recommend you to use it in the WebUI.
+Example script: [`sdxl_turbo.py`](./sdxl_turbo.py)
+
+We highly recommend you to use this model in the WebUI.
 
 |"black car"|"red car"|
 |-|-|
```
examples/image_synthesis/hunyuan_dit_text_to_image.py

Lines changed: 42 additions & 0 deletions

```diff
@@ -0,0 +1,42 @@
+from diffsynth import ModelManager, HunyuanDiTImagePipeline, download_models
+import torch
+
+
+# Download models (automatically)
+# `models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin`: [link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/clip_text_encoder/pytorch_model.bin)
+# `models/HunyuanDiT/t2i/mt5/pytorch_model.bin`: [link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/mt5/pytorch_model.bin)
+# `models/HunyuanDiT/t2i/model/pytorch_model_ema.pt`: [link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/model/pytorch_model_ema.pt)
+# `models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin`: [link](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/resolve/main/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin)
+download_models(["HunyuanDiT"])
+
+# Load models
+model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+model_manager.load_models([
+    "models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
+    "models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
+    "models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
+    "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
+])
+pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
+
+prompt = "一幅充满诗意美感的全身肖像画,画中一位银发、蓝色眼睛、身穿蓝色连衣裙的少女漂浮在水下,周围是光彩的气泡,和煦的阳光透过水面折射进水下"
+negative_prompt = "错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,"
+
+# Enjoy!
+torch.manual_seed(0)
+image = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    num_inference_steps=50, height=1024, width=1024,
+)
+image.save("image_1024.png")
+
+# Highres fix
+image = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    input_image=image.resize((2048, 2048)),
+    num_inference_steps=50, height=2048, width=2048,
+    denoising_strength=0.4, tiled=True,
+)
+image.save("image_2048.png")
```

(The file name was missing from the extracted page and is inferred from the README diff above. The prompt translates roughly as "a poetic, beautiful full-body portrait: a silver-haired, blue-eyed girl in a blue dress floating underwater, surrounded by shimmering bubbles, warm sunlight refracting through the water surface"; the negative prompt lists artifacts to avoid: "wrong eyes, bad face, disfigured, bad art, deformed, extra limbs, blurred colors, blurry, duplicated, morbid, mutilated,".)

examples/image_synthesis/sd3_text_to_image.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -6,7 +6,7 @@
 # `models/stable_diffusion_3/sd3_medium_incl_clips.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/sd3_medium_incl_clips.safetensors)
 download_models(["StableDiffusion3"])
 model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
-                             file_path_list=["models/stable_diffusion_3/sd3_medium_incl_clips_t5xxlfp16.safetensors"])
+                             file_path_list=["models/stable_diffusion_3/sd3_medium_incl_clips.safetensors"])
 pipe = SD3ImagePipeline.from_model_manager(model_manager)
 
 
```