Commit 40c6da8

Merge pull request #132 from modelscope/Artiprocher-rebuild
rebuild base modules
2 parents 9dfb7c1 + 3981b80

77 files changed, +3254 −3595 lines changed


diffsynth/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 from .data import *
 from .models import *
-from .prompts import *
+from .prompters import *
 from .schedulers import *
 from .pipelines import *
 from .controlnets import *

diffsynth/configs/model_config.py

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
from typing_extensions import Literal, TypeAlias

from ..models.sd_text_encoder import SDTextEncoder
from ..models.sd_unet import SDUNet
from ..models.sd_vae_encoder import SDVAEEncoder
from ..models.sd_vae_decoder import SDVAEDecoder

from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from ..models.sdxl_unet import SDXLUNet
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder

from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from ..models.sd3_dit import SD3DiT
from ..models.sd3_vae_decoder import SD3VAEDecoder
from ..models.sd3_vae_encoder import SD3VAEEncoder

from ..models.sd_controlnet import SDControlNet

from ..models.sd_motion import SDMotionModel
from ..models.sdxl_motion import SDXLMotionModel

from ..models.svd_image_encoder import SVDImageEncoder
from ..models.svd_unet import SVDUNet
from ..models.svd_vae_decoder import SVDVAEDecoder
from ..models.svd_vae_encoder import SVDVAEEncoder

from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder

from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.hunyuan_dit import HunyuanDiT


model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
    (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
    (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
    (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
    (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
    (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
    (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
    (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
    (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
    (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
    (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
    (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
    (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
    (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
    (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
    (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
    (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
    (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),  # SDXL-Turbo
    (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
    (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
    (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
    (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
    (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
]
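The table above keys each known checkpoint by a hash of its state_dict keys, optionally together with the tensor shapes. A minimal sketch of how such a table could drive automatic detection follows; the hashing helper and its key/shape serialization are illustrative assumptions, not the loader code in this commit.

import hashlib

def hash_state_dict_keys(state_dict, with_shape=False):
    # Hypothetical hashing scheme: join the sorted parameter names (and,
    # optionally, their shapes) and take an MD5 digest. The exact
    # serialization behind the hashes in the table is not shown in this diff.
    parts = []
    for name, tensor in sorted(state_dict.items()):
        parts.append(f"{name}:{tuple(tensor.shape)}" if with_shape else name)
    return hashlib.md5(",".join(parts).encode("utf-8")).hexdigest()

def detect_model_type(state_dict):
    # Compare both hash variants against the table and return the first match.
    keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
    keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
    for known_hash, known_hash_with_shape, model_names, model_classes, model_resource in model_loader_configs:
        if keys_hash == known_hash or keys_hash_with_shape == known_hash_with_shape:
            return model_names, model_classes, model_resource
    return None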
huggingface_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (architecture_in_huggingface_config, huggingface_lib, model_name)
    ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder"),
    ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator"),
    ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt"),
]
patch_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
    ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
]
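For models that ship HuggingFace-style folders, each entry maps the architectures field of config.json to the module that defines the class. A hedged sketch of how one entry might be resolved is below; the helper name and the from_pretrained call are illustrative assumptions rather than the repository's actual loader.

import importlib, json, os

def load_model_from_huggingface_folder(model_path):
    # Read the declared architecture, then import the class of the same name
    # from the module listed in huggingface_model_loader_configs.
    with open(os.path.join(model_path, "config.json")) as f:
        architecture = json.load(f)["architectures"][0]
    for known_architecture, huggingface_lib, model_name in huggingface_model_loader_configs:
        if known_architecture == architecture:
            model_class = getattr(importlib.import_module(huggingface_lib), known_architecture)
            return model_name, model_class.from_pretrained(model_path)
    raise ValueError(f"No loader config for architecture {architecture!r}")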
preset_models_on_huggingface = {
    "HunyuanDiT": [
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    "stable-video-diffusion-img2vid-xt": [
        ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
}
preset_models_on_modelscope = {
    # Hunyuan DiT
    "HunyuanDiT": [
        ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    # Stable Video Diffusion
    "stable-video-diffusion-img2vid-xt": [
        ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    # ExVideo
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
    # Stable Diffusion
    "StableDiffusion_v15": [
        ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
    ],
    "DreamShaper_8": [
        ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
    ],
    "AingDiffusion_v12": [
        ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
    ],
    "Flat2DAnimerge_v45Sharp": [
        ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
    ],
    # Textual Inversion
    "TextualInversion_VeryBadImageNegative_v1.3": [
        ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
    ],
    # Stable Diffusion XL
    "StableDiffusionXL_v1": [
        ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
    ],
    "BluePencilXL_v200": [
        ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
    ],
    "StableDiffusionXL_Turbo": [
        ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
    ],
    # Stable Diffusion 3
    "StableDiffusion3": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
    ],
    "StableDiffusion3_without_T5": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
    ],
    # ControlNet
    "ControlNet_v11f1p_sd15_depth": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "ControlNet_v11p_sd15_softedge": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
    ],
    "ControlNet_v11f1e_sd15_tile": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
    ],
    "ControlNet_v11p_sd15_lineart": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
    ],
    # AnimateDiff
    "AnimateDiff_v2": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
    ],
    "AnimateDiff_xl_beta": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
    ],
    # RIFE
    "RIFE": [
        ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
    ],
    # Beautiful Prompt
    "BeautifulPrompt": [
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
    ],
    # Translator
    "opus-mt-zh-en": [
        ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
    ],
    # IP-Adapter
    "IP-Adapter-SD": [
        ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
    ],
    "IP-Adapter-SDXL": [
        ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
    ],
    # Kolors
    "Kolors": [
        ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
        ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
    ],
    "SDXL-vae-fp16-fix": [
        ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
    ],
}
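Each preset entry is a (model_id, origin_file_path, local_dir) triple, so fetching a preset is just a loop over its files. A minimal sketch using huggingface_hub for the HuggingFace table is shown below; hf_hub_download is a real API, but using it here, and an analogous ModelScope file-download call for the other table, is an assumption about how these entries are consumed, not code from this commit.

from huggingface_hub import hf_hub_download

def fetch_preset_from_huggingface(preset_name):
    # Download every file listed for the preset into its target directory.
    for model_id, origin_file_path, local_dir in preset_models_on_huggingface[preset_name]:
        hf_hub_download(repo_id=model_id, filename=origin_file_path, local_dir=local_dir)

# Example: fetch_preset_from_huggingface("ExVideo-SVD-128f-v1")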
Preset_model_id: TypeAlias = Literal[
    "HunyuanDiT",
    "stable-video-diffusion-img2vid-xt",
    "ExVideo-SVD-128f-v1",
    "StableDiffusion_v15",
    "DreamShaper_8",
    "AingDiffusion_v12",
    "Flat2DAnimerge_v45Sharp",
    "TextualInversion_VeryBadImageNegative_v1.3",
    "StableDiffusionXL_v1",
    "BluePencilXL_v200",
    "StableDiffusionXL_Turbo",
    "ControlNet_v11f1p_sd15_depth",
    "ControlNet_v11p_sd15_softedge",
    "ControlNet_v11f1e_sd15_tile",
    "ControlNet_v11p_sd15_lineart",
    "AnimateDiff_v2",
    "AnimateDiff_xl_beta",
    "RIFE",
    "BeautifulPrompt",
    "opus-mt-zh-en",
    "IP-Adapter-SD",
    "IP-Adapter-SDXL",
    "StableDiffusion3",
    "StableDiffusion3_without_T5",
    "Kolors",
    "SDXL-vae-fp16-fix",
]
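Because Preset_model_id is a Literal alias over the preset names, download and pipeline helpers can annotate their parameters with it and get editor completion plus static checking of allowed names. The signature below is an illustrative assumption, not a function defined in this diff.

from typing import List

def download_models(model_id_list: List[Preset_model_id]):
    # Resolve each preset against the tables above and fetch its files.
    for model_id in model_id_list:
        ...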
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import torch, os
from torchvision import transforms
import pandas as pd
from PIL import Image


class TextImageDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
        self.steps_per_epoch = steps_per_epoch
        metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
        self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()
        self.image_processor = transforms.Compose(
            [
                transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
                transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )


    def __getitem__(self, index):
        data_id = torch.randint(0, len(self.path), (1,))[0]
        data_id = (data_id + index) % len(self.path)  # For fixed seed.
        text = self.text[data_id]
        image = Image.open(self.path[data_id]).convert("RGB")
        image = self.image_processor(image)
        return {"text": text, "image": image}


    def __len__(self):
        return self.steps_per_epoch
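Hedged usage sketch for the new dataset class: it reads train/metadata.csv (file_name and text columns) under dataset_path and yields {"text", "image"} dicts, so it plugs into a standard DataLoader. The dataset path, batch size, and worker count below are arbitrary examples.

from torch.utils.data import DataLoader

dataset = TextImageDataset("data/my_dataset", steps_per_epoch=1000, height=1024, width=1024)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)

batch = next(iter(dataloader))
print(batch["text"])           # list of 2 caption strings
print(batch["image"].shape)    # torch.Size([2, 3, 1024, 1024]), normalized to [-1, 1]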

diffsynth/extensions/RIFE/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -99,7 +99,8 @@ def forward(self, x, scale_list=[4, 2, 1], training=False):
             merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
         return flow_list, mask_list[2], merged

-    def state_dict_converter(self):
+    @staticmethod
+    def state_dict_converter():
         return IFNetStateDictConverter()
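With state_dict_converter now a @staticmethod, the converter can be obtained from the class itself rather than from an instance. A small illustrative call, assuming the enclosing class is IFNet (inferred from the converter's name) and is importable from this module:

from diffsynth.extensions.RIFE import IFNet

converter = IFNet.state_dict_converter()  # no IFNet(...) instance needed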