Commit 9920b8d: support SD3
1 parent: 237daa2
26 files changed: +329648, -2 lines

diffsynth/models/__init__.py (65 additions, 1 deletion)
@@ -16,6 +16,11 @@
 from .sdxl_vae_decoder import SDXLVAEDecoder
 from .sdxl_vae_encoder import SDXLVAEEncoder
 
+from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
+from .sd3_dit import SD3DiT
+from .sd3_vae_decoder import SD3VAEDecoder
+from .sd3_vae_encoder import SD3VAEEncoder
+
 from .sd_controlnet import SDControlNet
 
 from .sd_motion import SDMotionModel
@@ -90,6 +95,13 @@
     "StableDiffusionXL_Turbo": [
         ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
     ],
+    # Stable Diffusion 3
+    "StableDiffusion3": [
+        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
+    ],
+    "StableDiffusion3_without_T5": [
+        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
+    ],
     # ControlNet
     "ControlNet_v11f1p_sd15_depth": [
         ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
@@ -171,6 +183,8 @@
     "opus-mt-zh-en",
     "IP-Adapter-SD",
     "IP-Adapter-SDXL",
+    "StableDiffusion3",
+    "StableDiffusion3_without_T5"
 ]
 Preset_model_website: TypeAlias = Literal[
     "HuggingFace",
@@ -297,7 +311,8 @@ def is_hunyuan_dit_clip_text_encoder(self, state_dict):
 
     def is_hunyuan_dit_t5_text_encoder(self, state_dict):
         param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
-        return param_name in state_dict
+        param_name_ = "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
+        return param_name in state_dict and param_name_ in state_dict
 
     def is_hunyuan_dit(self, state_dict):
         param_name = "final_layer.adaLN_modulation.1.weight"
@@ -311,6 +326,23 @@ def is_ExVideo_StableVideoDiffusion(self, state_dict):
         param_name = "blocks.185.positional_embedding.embeddings"
         return param_name in state_dict
 
+    def is_stable_diffusion_3(self, state_dict):
+        param_names = [
+            "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
+            "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
+            "model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight",
+            "first_stage_model.encoder.mid.block_2.norm2.weight",
+            "first_stage_model.decoder.mid.block_2.norm2.weight",
+        ]
+        for param_name in param_names:
+            if param_name not in state_dict:
+                return False
+        return True
+
+    def is_stable_diffusion_3_t5(self, state_dict):
+        param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
+        return param_name in state_dict
+
     def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
         component_dict = {
             "image_encoder": SVDImageEncoder,
@@ -520,6 +552,34 @@ def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
         self.model["unet"].load_state_dict(state_dict, strict=False)
         self.model["unet"].to(self.torch_dtype).to(self.device)
 
+    def load_stable_diffusion_3(self, state_dict, components=None, file_path=""):
+        component_dict = {
+            "sd3_text_encoder_1": SD3TextEncoder1,
+            "sd3_text_encoder_2": SD3TextEncoder2,
+            "sd3_text_encoder_3": SD3TextEncoder3,
+            "sd3_dit": SD3DiT,
+            "sd3_vae_decoder": SD3VAEDecoder,
+            "sd3_vae_encoder": SD3VAEEncoder,
+        }
+        if components is None:
+            components = ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_decoder", "sd3_vae_encoder"]
+        for component in components:
+            if component == "sd3_text_encoder_3":
+                if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
+                    continue
+            self.model[component] = component_dict[component]()
+            self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
+            self.model[component].to(self.torch_dtype).to(self.device)
+            self.model_path[component] = file_path
+
+    def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
+        component = "sd3_text_encoder_3"
+        model = SD3TextEncoder3()
+        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
+        model.to(self.torch_dtype).to(self.device)
+        self.model[component] = model
+        self.model_path[component] = file_path
+
     def search_for_embeddings(self, state_dict):
         embeddings = []
         for k in state_dict:
@@ -587,6 +647,10 @@ def load_model(self, file_path, components=None, lora_alphas=[]):
             self.load_diffusers_vae(state_dict, file_path=file_path)
         elif self.is_ExVideo_StableVideoDiffusion(state_dict):
             self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
+        elif self.is_stable_diffusion_3(state_dict):
+            self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
+        elif self.is_stable_diffusion_3_t5(state_dict):
+            self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
 
     def load_models(self, file_path_list, lora_alphas=[]):
         for file_path in file_path_list:
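
Taken together, the diff wires SD3 into the checkpoint-sniffing loader: load_model fingerprints a state dict by parameter names, routes a bundled single-file SD3 checkpoint to load_stable_diffusion_3, and routes a standalone T5-XXL checkpoint to load_stable_diffusion_3_t5. A minimal sketch of how the new path would be exercised, assuming the ModelManager class that hosts these methods and local checkpoint paths matching the presets above (the import path and constructor signature are assumptions; the diff only shows self.torch_dtype and self.device being used):

    import torch
    from diffsynth.models import ModelManager  # class hosting load_model(); import path assumed

    # Assumed constructor; the new loaders only rely on self.torch_dtype and self.device.
    model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")

    # A bundled checkpoint (CLIP-L + CLIP-G + T5-XXL + DiT + VAE in one file) matches
    # is_stable_diffusion_3(); load_stable_diffusion_3() then builds all six components,
    # skipping sd3_text_encoder_3 when the T5 weights are absent from the file.
    model_manager.load_models([
        "models/stable_diffusion_3/sd3_medium_incl_clips_t5xxlfp16.safetensors",
    ])

    # With the "without_T5" preset, the T5 encoder can be supplied later from a separate
    # file: a checkpoint matching is_stable_diffusion_3_t5() is routed to
    # load_stable_diffusion_3_t5(), which fills in sd3_text_encoder_3 on its own.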
