
Commit 8be4fad

Merge pull request #94 from modelscope/Artiprocher-sd3
support SD3
2 parents: 237daa2 + 8113f95

30 files changed: +329741 -5 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu
 Until now, DiffSynth Studio has supported the following models:
 
 * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
+* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
 * [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
 * [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
 * [RIFE](https://github.com/hzwer/ECCV2022-RIFE)

diffsynth/models/__init__.py

Lines changed: 77 additions & 1 deletion

@@ -16,6 +16,11 @@
 from .sdxl_vae_decoder import SDXLVAEDecoder
 from .sdxl_vae_encoder import SDXLVAEEncoder
 
+from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
+from .sd3_dit import SD3DiT
+from .sd3_vae_decoder import SD3VAEDecoder
+from .sd3_vae_encoder import SD3VAEEncoder
+
 from .sd_controlnet import SDControlNet
 
 from .sd_motion import SDMotionModel
@@ -90,6 +95,13 @@
     "StableDiffusionXL_Turbo": [
         ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
     ],
+    # Stable Diffusion 3
+    "StableDiffusion3": [
+        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
+    ],
+    "StableDiffusion3_without_T5": [
+        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
+    ],
     # ControlNet
     "ControlNet_v11f1p_sd15_depth": [
         ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
@@ -171,6 +183,8 @@
     "opus-mt-zh-en",
     "IP-Adapter-SD",
     "IP-Adapter-SDXL",
+    "StableDiffusion3",
+    "StableDiffusion3_without_T5"
 ]
 Preset_model_website: TypeAlias = Literal[
     "HuggingFace",
@@ -297,7 +311,8 @@ def is_hunyuan_dit_clip_text_encoder(self, state_dict):
 
     def is_hunyuan_dit_t5_text_encoder(self, state_dict):
         param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
-        return param_name in state_dict
+        param_name_ = "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
+        return param_name in state_dict and param_name_ in state_dict
 
     def is_hunyuan_dit(self, state_dict):
         param_name = "final_layer.adaLN_modulation.1.weight"
@@ -311,6 +326,23 @@ def is_ExVideo_StableVideoDiffusion(self, state_dict):
         param_name = "blocks.185.positional_embedding.embeddings"
         return param_name in state_dict
 
+    def is_stable_diffusion_3(self, state_dict):
+        param_names = [
+            "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
+            "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
+            "model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight",
+            "first_stage_model.encoder.mid.block_2.norm2.weight",
+            "first_stage_model.decoder.mid.block_2.norm2.weight",
+        ]
+        for param_name in param_names:
+            if param_name not in state_dict:
+                return False
+        return True
+
+    def is_stable_diffusion_3_t5(self, state_dict):
+        param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
+        return param_name in state_dict
+
     def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
         component_dict = {
             "image_encoder": SVDImageEncoder,
@@ -520,6 +552,46 @@ def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
         self.model["unet"].load_state_dict(state_dict, strict=False)
         self.model["unet"].to(self.torch_dtype).to(self.device)
 
+    def load_stable_diffusion_3(self, state_dict, components=None, file_path=""):
+        component_dict = {
+            "sd3_text_encoder_1": SD3TextEncoder1,
+            "sd3_text_encoder_2": SD3TextEncoder2,
+            "sd3_text_encoder_3": SD3TextEncoder3,
+            "sd3_dit": SD3DiT,
+            "sd3_vae_decoder": SD3VAEDecoder,
+            "sd3_vae_encoder": SD3VAEEncoder,
+        }
+        if components is None:
+            components = ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_decoder", "sd3_vae_encoder"]
+        for component in components:
+            if component == "sd3_text_encoder_3":
+                if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
+                    continue
+            elif component == "sd3_text_encoder_1":
+                # Add additional token embeddings to text encoder
+                token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
+                for keyword in self.textual_inversion_dict:
+                    _, embeddings = self.textual_inversion_dict[keyword]
+                    token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
+                token_embeddings = torch.concat(token_embeddings, dim=0)
+                state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
+                self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
+                self.model[component].to(self.torch_dtype).to(self.device)
+            else:
+                self.model[component] = component_dict[component]()
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
+                self.model[component].to(self.torch_dtype).to(self.device)
+            self.model_path[component] = file_path
+
+    def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
+        component = "sd3_text_encoder_3"
+        model = SD3TextEncoder3()
+        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
+        model.to(self.torch_dtype).to(self.device)
+        self.model[component] = model
+        self.model_path[component] = file_path
+
     def search_for_embeddings(self, state_dict):
         embeddings = []
         for k in state_dict:
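The `sd3_text_encoder_1` branch above grows the CLIP-L token-embedding table before the model is built, so textual-inversion tokens can be used. A minimal sketch of that step with assumed shapes (CLIP-L's 49408-token, 768-dim table; the 4-token concept is made up):

```python
import torch

base = torch.zeros(49408, 768)       # original CLIP-L token embedding table (assumed size)
ti_concept = torch.zeros(4, 768)     # one textual-inversion concept spanning 4 tokens (made up)
extended = torch.concat([base, ti_concept], dim=0)

# The enlarged table replaces the original entry in the state dict, and its row
# count is passed as vocab_size when SD3TextEncoder1 is constructed.
print(extended.shape[0])  # 49412
```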
@@ -587,6 +659,10 @@ def load_model(self, file_path, components=None, lora_alphas=[]):
             self.load_diffusers_vae(state_dict, file_path=file_path)
         elif self.is_ExVideo_StableVideoDiffusion(state_dict):
             self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
+        elif self.is_stable_diffusion_3(state_dict):
+            self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
+        elif self.is_stable_diffusion_3_t5(state_dict):
+            self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
 
     def load_models(self, file_path_list, lora_alphas=[]):
         for file_path in file_path_list: