Commit 9c6607f

support kolors! (#106)

1 parent 2a4709e commit 9c6607f

20 files changed: +2510 −281 lines changed

README.md

Lines changed: 7 additions & 4 deletions

@@ -8,6 +8,7 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu
 Until now, DiffSynth Studio has supported the following models:
 
 * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
+* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
 * [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
 * [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
 * [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
@@ -85,11 +86,13 @@ Generate high-resolution images, by breaking the limitation of diffusion models!
 
 LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
 
-|Stable Diffusion|Stable Diffusion XL|
+|Model|Example|
 |-|-|
-|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|
-|Stable Diffusion 3|Hunyuan-DiT|
-|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
+|Stable Diffusion|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|
+|Stable Diffusion XL|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|
+|Stable Diffusion 3|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|
+|Kolors|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|
+|Hunyuan-DiT|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
 
 ### Toon Shading
 
diffsynth/models/__init__.py

Lines changed: 62 additions & 7 deletions

@@ -1,4 +1,4 @@
-import torch, os
+import torch, os, json
 from safetensors import safe_open
 from typing_extensions import Literal, TypeAlias
 from typing import List
@@ -36,6 +36,7 @@
 from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
 from .hunyuan_dit import HunyuanDiT
+from .kolors_text_encoder import ChatGLMModel
 
 
 preset_models_on_huggingface = {
@@ -159,6 +160,20 @@
         ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
         ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
     ],
+    # Kolors
+    "Kolors": [
+        ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+        ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
+        ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
+    ],
 }
 Preset_model_id: TypeAlias = Literal[
     "HunyuanDiT",
@@ -184,7 +199,8 @@
     "IP-Adapter-SD",
     "IP-Adapter-SDXL",
     "StableDiffusion3",
-    "StableDiffusion3_without_T5"
+    "StableDiffusion3_without_T5",
+    "Kolors",
 ]
 Preset_model_website: TypeAlias = Literal[
     "HuggingFace",
@@ -272,8 +288,7 @@ def is_stable_diffusion(self, state_dict):
 
     def is_controlnet(self, state_dict):
         param_name = "control_model.time_embed.0.weight"
-        param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
-        return param_name in state_dict or param_name_2 in state_dict
+        return param_name in state_dict
 
     def is_animatediff(self, state_dict):
         param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
@@ -343,6 +358,21 @@ def is_stable_diffusion_3_t5(self, state_dict):
         param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
         return param_name in state_dict
 
+    def is_kolors_text_encoder(self, file_path):
+        file_list = os.listdir(file_path)
+        if "config.json" in file_list:
+            try:
+                with open(os.path.join(file_path, "config.json"), "r") as f:
+                    config = json.load(f)
+                    if config.get("model_type") == "chatglm":
+                        return True
+            except:
+                pass
+        return False
+
+    def is_kolors_unet(self, state_dict):
+        return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict
+
     def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
         component_dict = {
             "image_encoder": SVDImageEncoder,
@@ -532,13 +562,13 @@ def load_diffusers_vae(self, state_dict, file_path=""):
         component = "vae_encoder"
         model = SDXLVAEEncoder()
         model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
-        model.to(self.torch_dtype).to(self.device)
+        model.to(torch.float32).to(self.device)
         self.model[component] = model
         self.model_path[component] = file_path
         component = "vae_decoder"
         model = SDXLVAEDecoder()
         model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
-        model.to(self.torch_dtype).to(self.device)
+        model.to(torch.float32).to(self.device)
         self.model[component] = model
         self.model_path[component] = file_path
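Pinning the diffusers-format SDXL VAE to float32 rather than self.torch_dtype matches the common workaround for the SDXL VAE overflowing in float16.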

@@ -592,6 +622,21 @@ def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
         self.model[component] = model
         self.model_path[component] = file_path
 
+    def load_kolors_text_encoder(self, state_dict=None, file_path=""):
+        component = "kolors_text_encoder"
+        model = ChatGLMModel.from_pretrained(file_path, torch_dtype=self.torch_dtype)
+        model = model.to(dtype=self.torch_dtype, device=self.device)
+        self.model[component] = model
+        self.model_path[component] = file_path
+
+    def load_kolors_unet(self, state_dict, file_path=""):
+        component = "kolors_unet"
+        model = SDXLUNet(is_kolors=True)
+        model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
+        model.to(self.torch_dtype).to(self.device)
+        self.model[component] = model
+        self.model_path[component] = file_path
+
     def search_for_embeddings(self, state_dict):
         embeddings = []
         for k in state_dict:
@@ -607,7 +652,11 @@ def load_textual_inversions(self, folder):
 
         # Load every textual inversion file
         for file_name in os.listdir(folder):
-            if file_name.endswith(".txt"):
+            if os.path.isdir(os.path.join(folder, file_name)) or \
+                not (file_name.endswith(".bin") or \
+                     file_name.endswith(".safetensors") or \
+                     file_name.endswith(".pth") or \
+                     file_name.endswith(".pt")):
                 continue
             keyword = os.path.splitext(file_name)[0]
             state_dict = load_state_dict(os.path.join(folder, file_name))
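The loader previously skipped only .txt files and fed everything else to load_state_dict; it now skips subdirectories and accepts only weight-file extensions (.bin, .safetensors, .pth, .pt), so stray README or preview files inside an embeddings folder no longer break loading.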
@@ -620,6 +669,10 @@ def load_textual_inversions(self, folder):
                 break
 
     def load_model(self, file_path, components=None, lora_alphas=[]):
+        if os.path.isdir(file_path):
+            if self.is_kolors_text_encoder(file_path):
+                self.load_kolors_text_encoder(file_path=file_path)
+            return
         state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
         if self.is_stable_video_diffusion(state_dict):
             self.load_stable_video_diffusion(state_dict, file_path=file_path)
@@ -663,6 +716,8 @@ def load_model(self, file_path, components=None, lora_alphas=[]):
             self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
         elif self.is_stable_diffusion_3_t5(state_dict):
             self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
+        elif self.is_kolors_unet(state_dict):
+            self.load_kolors_unet(state_dict, file_path=file_path)
 
     def load_models(self, file_path_list, lora_alphas=[]):
         for file_path in file_path_list:
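Taken together, the new directory branch and the two loaders make Kolors loadable like any other model here. A hedged end-to-end sketch, assuming SDXLImagePipeline (whose UNet class the commit reuses via SDXLUNet(is_kolors=True)) offers the usual from_model_manager entry point and call signature:

import torch
from diffsynth import ModelManager, SDXLImagePipeline, download_models

download_models(["Kolors"])
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([
    "models/kolors/Kolors/text_encoder",                              # directory -> load_kolors_text_encoder
    "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",  # matched by is_kolors_unet
    "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",   # diffusers VAE, kept in float32
])
pipe = SDXLImagePipeline.from_model_manager(model_manager)

torch.manual_seed(0)
image = pipe(prompt="a painting full of poetic beauty, ultra-detailed", num_inference_steps=30)
image.save("kolors_example.jpg")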
