3636 "width" : 512 ,
3737 }
3838 },
39+ "Kolors" : {
40+ "model_folder" : "models/kolors" ,
41+ "pipeline_class" : SDXLImagePipeline ,
42+ "fixed_parameters" : {}
43+ },
3944 "HunyuanDiT" : {
4045 "model_folder" : "models/HunyuanDiT" ,
4146 "pipeline_class" : HunyuanDiTImagePipeline ,
5055def load_model_list (model_type ):
5156 folder = config [model_type ]["model_folder" ]
5257 file_list = [i for i in os .listdir (folder ) if i .endswith (".safetensors" )]
53- if model_type == "HunyuanDiT" :
58+ if model_type in [ "HunyuanDiT" , "Kolors" ] :
5459 file_list += [i for i in os .listdir (folder ) if os .path .isdir (os .path .join (folder , i ))]
5560 file_list = sorted (file_list )
5661 return file_list
@@ -74,6 +79,12 @@ def load_model(model_type, model_path):
7479 os .path .join (model_path , "model/pytorch_model_ema.pt" ),
7580 os .path .join (model_path , "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin" ),
7681 ])
82+ elif model_type == "Kolors" :
83+ model_manager .load_models ([
84+ os .path .join (model_path , "text_encoder" ),
85+ os .path .join (model_path , "unet/diffusion_pytorch_model.safetensors" ),
86+ os .path .join (model_path , "vae/diffusion_pytorch_model.safetensors" ),
87+ ])
7788 else :
7889 model_manager .load_model (model_path )
7990 pipeline = config [model_type ]["pipeline_class" ].from_model_manager (model_manager )
0 commit comments