1- import torch , os
1+ import torch , os , json
22from safetensors import safe_open
33from typing_extensions import Literal , TypeAlias
44from typing import List
3636
3737from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder , HunyuanDiTT5TextEncoder
3838from .hunyuan_dit import HunyuanDiT
39+ from .kolors_text_encoder import ChatGLMModel
3940
4041
4142preset_models_on_huggingface = {
159160 ("AI-ModelScope/IP-Adapter" , "sdxl_models/image_encoder/model.safetensors" , "models/IpAdapter/stable_diffusion_xl/image_encoder" ),
160161 ("AI-ModelScope/IP-Adapter" , "sdxl_models/ip-adapter_sdxl.bin" , "models/IpAdapter/stable_diffusion_xl" ),
161162 ],
163+ # Kolors
164+ "Kolors" : [
165+ ("Kwai-Kolors/Kolors" , "text_encoder/config.json" , "models/kolors/Kolors/text_encoder" ),
166+ ("Kwai-Kolors/Kolors" , "text_encoder/pytorch_model.bin.index.json" , "models/kolors/Kolors/text_encoder" ),
167+ ("Kwai-Kolors/Kolors" , "text_encoder/pytorch_model-00001-of-00007.bin" , "models/kolors/Kolors/text_encoder" ),
168+ ("Kwai-Kolors/Kolors" , "text_encoder/pytorch_model-00002-of-00007.bin" , "models/kolors/Kolors/text_encoder" ),
169+ ("Kwai-Kolors/Kolors" , "text_encoder/pytorch_model-00003-of-00007.bin" , "models/kolors/Kolors/text_encoder" ),
170+ ("Kwai-Kolors/Kolors" , "text_encoder/pytorch_model-00004-of-00007.bin" , "models/kolors/Kolors/text_encoder" ),
171+ ("Kwai-Kolors/Kolors" , "text_encoder/pytorch_model-00005-of-00007.bin" , "models/kolors/Kolors/text_encoder" ),
172+ ("Kwai-Kolors/Kolors" , "text_encoder/pytorch_model-00006-of-00007.bin" , "models/kolors/Kolors/text_encoder" ),
173+ ("Kwai-Kolors/Kolors" , "text_encoder/pytorch_model-00007-of-00007.bin" , "models/kolors/Kolors/text_encoder" ),
174+ ("Kwai-Kolors/Kolors" , "unet/diffusion_pytorch_model.safetensors" , "models/kolors/Kolors/unet" ),
175+ ("Kwai-Kolors/Kolors" , "vae/diffusion_pytorch_model.safetensors" , "models/kolors/Kolors/vae" ),
176+ ],
162177}
163178Preset_model_id : TypeAlias = Literal [
164179 "HunyuanDiT" ,
184199 "IP-Adapter-SD" ,
185200 "IP-Adapter-SDXL" ,
186201 "StableDiffusion3" ,
187- "StableDiffusion3_without_T5"
202+ "StableDiffusion3_without_T5" ,
203+ "Kolors" ,
188204]
189205Preset_model_website : TypeAlias = Literal [
190206 "HuggingFace" ,
@@ -272,8 +288,7 @@ def is_stable_diffusion(self, state_dict):
272288
273289 def is_controlnet (self , state_dict ):
274290 param_name = "control_model.time_embed.0.weight"
275- param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
276- return param_name in state_dict or param_name_2 in state_dict
291+ return param_name in state_dict
277292
278293 def is_animatediff (self , state_dict ):
279294 param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
@@ -343,6 +358,21 @@ def is_stable_diffusion_3_t5(self, state_dict):
343358 param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
344359 return param_name in state_dict
345360
361+ def is_kolors_text_encoder (self , file_path ):
362+ file_list = os .listdir (file_path )
363+ if "config.json" in file_list :
364+ try :
365+ with open (os .path .join (file_path , "config.json" ), "r" ) as f :
366+ config = json .load (f )
367+ if config .get ("model_type" ) == "chatglm" :
368+ return True
369+ except :
370+ pass
371+ return False
372+
373+ def is_kolors_unet (self , state_dict ):
374+ return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict
375+
346376 def load_stable_video_diffusion (self , state_dict , components = None , file_path = "" , add_positional_conv = None ):
347377 component_dict = {
348378 "image_encoder" : SVDImageEncoder ,
@@ -532,13 +562,13 @@ def load_diffusers_vae(self, state_dict, file_path=""):
532562 component = "vae_encoder"
533563 model = SDXLVAEEncoder ()
534564 model .load_state_dict (model .state_dict_converter ().from_diffusers (state_dict ))
535- model .to (self . torch_dtype ).to (self .device )
565+ model .to (torch . float32 ).to (self .device )
536566 self .model [component ] = model
537567 self .model_path [component ] = file_path
538568 component = "vae_decoder"
539569 model = SDXLVAEDecoder ()
540570 model .load_state_dict (model .state_dict_converter ().from_diffusers (state_dict ))
541- model .to (self . torch_dtype ).to (self .device )
571+ model .to (torch . float32 ).to (self .device )
542572 self .model [component ] = model
543573 self .model_path [component ] = file_path
544574
@@ -592,6 +622,21 @@ def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
592622 self .model [component ] = model
593623 self .model_path [component ] = file_path
594624
625+ def load_kolors_text_encoder (self , state_dict = None , file_path = "" ):
626+ component = "kolors_text_encoder"
627+ model = ChatGLMModel .from_pretrained (file_path , torch_dtype = self .torch_dtype )
628+ model = model .to (dtype = self .torch_dtype , device = self .device )
629+ self .model [component ] = model
630+ self .model_path [component ] = file_path
631+
632+ def load_kolors_unet (self , state_dict , file_path = "" ):
633+ component = "kolors_unet"
634+ model = SDXLUNet (is_kolors = True )
635+ model .load_state_dict (model .state_dict_converter ().from_diffusers (state_dict ))
636+ model .to (self .torch_dtype ).to (self .device )
637+ self .model [component ] = model
638+ self .model_path [component ] = file_path
639+
595640 def search_for_embeddings (self , state_dict ):
596641 embeddings = []
597642 for k in state_dict :
@@ -607,7 +652,11 @@ def load_textual_inversions(self, folder):
607652
608653 # Load every textual inversion file
609654 for file_name in os .listdir (folder ):
610- if file_name .endswith (".txt" ):
655+ if os .path .isdir (os .path .join (folder , file_name )) or \
656+ not (file_name .endswith (".bin" ) or \
657+ file_name .endswith (".safetensors" ) or \
658+ file_name .endswith (".pth" ) or \
659+ file_name .endswith (".pt" )):
611660 continue
612661 keyword = os .path .splitext (file_name )[0 ]
613662 state_dict = load_state_dict (os .path .join (folder , file_name ))
@@ -620,6 +669,10 @@ def load_textual_inversions(self, folder):
620669 break
621670
622671 def load_model (self , file_path , components = None , lora_alphas = []):
672+ if os .path .isdir (file_path ):
673+ if self .is_kolors_text_encoder (file_path ):
674+ self .load_kolors_text_encoder (file_path = file_path )
675+ return
623676 state_dict = load_state_dict (file_path , torch_dtype = self .torch_dtype )
624677 if self .is_stable_video_diffusion (state_dict ):
625678 self .load_stable_video_diffusion (state_dict , file_path = file_path )
@@ -663,6 +716,8 @@ def load_model(self, file_path, components=None, lora_alphas=[]):
663716 self .load_stable_diffusion_3 (state_dict , components = components , file_path = file_path )
664717 elif self .is_stable_diffusion_3_t5 (state_dict ):
665718 self .load_stable_diffusion_3_t5 (state_dict , file_path = file_path )
719+ elif self .is_kolors_unet (state_dict ):
720+ self .load_kolors_unet (state_dict , file_path = file_path )
666721
667722 def load_models (self , file_path_list , lora_alphas = []):
668723 for file_path in file_path_list :
0 commit comments