@@ -16,6 +16,11 @@
 from .sdxl_vae_decoder import SDXLVAEDecoder
 from .sdxl_vae_encoder import SDXLVAEEncoder
 
+from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
+from .sd3_dit import SD3DiT
+from .sd3_vae_decoder import SD3VAEDecoder
+from .sd3_vae_encoder import SD3VAEEncoder
+
 from .sd_controlnet import SDControlNet
 
 from .sd_motion import SDMotionModel
9095 "StableDiffusionXL_Turbo" : [
9196 ("AI-ModelScope/sdxl-turbo" , "sd_xl_turbo_1.0_fp16.safetensors" , "models/stable_diffusion_xl_turbo" ),
9297 ],
98+ # Stable Diffusion 3
99+ "StableDiffusion3" : [
100+ ("AI-ModelScope/stable-diffusion-3-medium" , "sd3_medium_incl_clips_t5xxlfp16.safetensors" , "models/stable_diffusion_3" ),
101+ ],
102+ "StableDiffusion3_without_T5" : [
103+ ("AI-ModelScope/stable-diffusion-3-medium" , "sd3_medium_incl_clips.safetensors" , "models/stable_diffusion_3" ),
104+ ],
93105 # ControlNet
94106 "ControlNet_v11f1p_sd15_depth" : [
95107 ("AI-ModelScope/ControlNet-v1-1" , "control_v11f1p_sd15_depth.pth" , "models/ControlNet" ),
171183 "opus-mt-zh-en" ,
172184 "IP-Adapter-SD" ,
173185 "IP-Adapter-SDXL" ,
186+ "StableDiffusion3" ,
187+ "StableDiffusion3_without_T5"
174188]
175189Preset_model_website : TypeAlias = Literal [
176190 "HuggingFace" ,
@@ -297,7 +311,8 @@ def is_hunyuan_dit_clip_text_encoder(self, state_dict):
 
     def is_hunyuan_dit_t5_text_encoder(self, state_dict):
         param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
-        return param_name in state_dict
+        param_name_ = "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
+        return param_name in state_dict and param_name_ in state_dict
 
     def is_hunyuan_dit(self, state_dict):
         param_name = "final_layer.adaLN_modulation.1.weight"
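Requiring the decoder-side key here is what keeps Hunyuan-DiT's T5 detection from claiming Stable Diffusion 3's standalone T5-XXL checkpoint: SD3's file is encoder-only, and is_stable_diffusion_3_t5 below probes exactly the single key this method used to rely on. A small illustration with hypothetical stand-in dicts:

# Hypothetical stand-in state dicts; real checkpoints map these keys to tensors.
enc_key = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
dec_key = "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"

sd3_t5 = {enc_key: 0}                  # encoder-only T5-XXL, as shipped for SD3
hunyuan_t5 = {enc_key: 0, dec_key: 0}  # encoder-decoder text encoder checkpoint

assert not (enc_key in sd3_t5 and dec_key in sd3_t5)    # tightened check rejects it
assert enc_key in hunyuan_t5 and dec_key in hunyuan_t5  # Hunyuan-DiT still matches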
@@ -311,6 +326,23 @@ def is_ExVideo_StableVideoDiffusion(self, state_dict):
         param_name = "blocks.185.positional_embedding.embeddings"
         return param_name in state_dict
 
+    def is_stable_diffusion_3(self, state_dict):
+        param_names = [
+            "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
+            "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
+            "model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight",
+            "first_stage_model.encoder.mid.block_2.norm2.weight",
+            "first_stage_model.decoder.mid.block_2.norm2.weight",
+        ]
+        for param_name in param_names:
+            if param_name not in state_dict:
+                return False
+        return True
+
+    def is_stable_diffusion_3_t5(self, state_dict):
+        param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
+        return param_name in state_dict
+
     def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
         component_dict = {
             "image_encoder": SVDImageEncoder,
@@ -520,6 +552,46 @@ def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
         self.model["unet"].load_state_dict(state_dict, strict=False)
         self.model["unet"].to(self.torch_dtype).to(self.device)
 
+    def load_stable_diffusion_3(self, state_dict, components=None, file_path=""):
+        component_dict = {
+            "sd3_text_encoder_1": SD3TextEncoder1,
+            "sd3_text_encoder_2": SD3TextEncoder2,
+            "sd3_text_encoder_3": SD3TextEncoder3,
+            "sd3_dit": SD3DiT,
+            "sd3_vae_decoder": SD3VAEDecoder,
+            "sd3_vae_encoder": SD3VAEEncoder,
+        }
+        if components is None:
+            components = ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_decoder", "sd3_vae_encoder"]
+        for component in components:
+            if component == "sd3_text_encoder_3":
+                if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
+                    continue
+            if component == "sd3_text_encoder_1":
+                # Add additional token embeddings to text encoder
+                token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
+                for keyword in self.textual_inversion_dict:
+                    _, embeddings = self.textual_inversion_dict[keyword]
+                    token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
+                token_embeddings = torch.concat(token_embeddings, dim=0)
+                state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
+                self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
+                self.model[component].to(self.torch_dtype).to(self.device)
+            else:
+                self.model[component] = component_dict[component]()
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
+                self.model[component].to(self.torch_dtype).to(self.device)
+            self.model_path[component] = file_path
+
+    def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
+        component = "sd3_text_encoder_3"
+        model = SD3TextEncoder3()
+        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
+        model.to(self.torch_dtype).to(self.device)
+        self.model[component] = model
+        self.model_path[component] = file_path
+
     def search_for_embeddings(self, state_dict):
         embeddings = []
         for k in state_dict:
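Two details of load_stable_diffusion_3 are easy to miss: when the checkpoint lacks t5xxl weights (the _without_T5 preset), sd3_text_encoder_3 is skipped silently rather than raising, and any textual-inversion embeddings already registered in self.textual_inversion_dict are concatenated onto CLIP-L's token table, which is why SD3TextEncoder1 is constructed with vocab_size=token_embeddings.shape[0]. The components argument also permits partial loads; a sketch, assuming a manager instance of this class:

# Load only the DiT and the two VAE halves from the combined file;
# text encoders can be attached later from another file.
manager.load_model(
    "models/stable_diffusion_3/sd3_medium_incl_clips.safetensors",
    components=["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"],
)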
@@ -587,6 +659,10 @@ def load_model(self, file_path, components=None, lora_alphas=[]):
             self.load_diffusers_vae(state_dict, file_path=file_path)
         elif self.is_ExVideo_StableVideoDiffusion(state_dict):
             self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
+        elif self.is_stable_diffusion_3(state_dict):
+            self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
+        elif self.is_stable_diffusion_3_t5(state_dict):
+            self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
 
     def load_models(self, file_path_list, lora_alphas=[]):
         for file_path in file_path_list:
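The two new branches slot into load_model's detection chain: a combined SD3 checkpoint, whose text-encoder keys carry the text_encoders.* prefix, matches is_stable_diffusion_3, while a bare encoder-only T5 file, whose keys carry no such prefix, falls through to is_stable_diffusion_3_t5. A sketch of loading both in one call, assuming the same manager instance as above (the standalone T5 filename is hypothetical):

manager.load_models([
    "models/stable_diffusion_3/sd3_medium_incl_clips.safetensors",  # -> load_stable_diffusion_3
    "models/stable_diffusion_3/t5xxl_fp16.safetensors",             # -> load_stable_diffusion_3_t5 (hypothetical file)
])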