Skip to content

Commit 518c6d6

Browse files
committed
support SD3 textual inversion
1 parent 9920b8d commit 518c6d6

File tree

5 files changed

+70
-13
lines changed

5 files changed

+70
-13
lines changed

diffsynth/models/__init__.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -567,10 +567,22 @@ def load_stable_diffusion_3(self, state_dict, components=None, file_path=""):
567567
if component == "sd3_text_encoder_3":
568568
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
569569
continue
570-
self.model[component] = component_dict[component]()
571-
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
572-
self.model[component].to(self.torch_dtype).to(self.device)
573-
self.model_path[component] = file_path
570+
elif component == "sd3_text_encoder_1":
571+
# Add additional token embeddings to text encoder
572+
token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
573+
for keyword in self.textual_inversion_dict:
574+
_, embeddings = self.textual_inversion_dict[keyword]
575+
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
576+
token_embeddings = torch.concat(token_embeddings, dim=0)
577+
state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
578+
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
579+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
580+
self.model[component].to(self.torch_dtype).to(self.device)
581+
else:
582+
self.model[component] = component_dict[component]()
583+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
584+
self.model[component].to(self.torch_dtype).to(self.device)
585+
self.model_path[component] = file_path
574586

575587
def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
576588
component = "sd3_text_encoder_3"

diffsynth/models/sd3_text_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66

77
class SD3TextEncoder1(SDTextEncoder):
    """CLIP-L text encoder for Stable Diffusion 3.

    vocab_size defaults to the standard CLIP vocabulary (49408) and may be
    enlarged when textual-inversion embeddings are appended to the
    token-embedding table.
    """

    def __init__(self, vocab_size=49408):
        super().__init__(vocab_size=vocab_size)
1010

1111
def forward(self, input_ids, clip_skip=2):
1212
embeds = self.token_embedding(input_ids) + self.position_embeds

diffsynth/prompts/sd3_prompter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(
2020
base_path = os.path.dirname(os.path.dirname(__file__))
2121
tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
2222
super().__init__()
23-
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
23+
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_1_path)
2424
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
2525
self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
2626

@@ -61,17 +61,17 @@ def encode_prompt(
6161
positive=True,
6262
device="cuda"
6363
):
64-
prompt = self.process_prompt(prompt, positive=positive)
64+
prompt, pure_prompt = self.process_prompt(prompt, positive=positive, require_pure_prompt=True)
6565

6666
# CLIP
67-
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer_1, 77, device)
68-
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, text_encoder_2, self.tokenizer_2, 77, device)
67+
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer, 77, device)
68+
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(pure_prompt, text_encoder_2, self.tokenizer_2, 77, device)
6969

7070
# T5
7171
if text_encoder_3 is None:
7272
prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
7373
else:
74-
prompt_emb_3 = self.encode_prompt_using_t5(prompt, text_encoder_3, self.tokenizer_3, 256, device)
74+
prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
7575
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
7676

7777
# Merge

diffsynth/prompts/utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,27 @@ def load_from_model_manager(self, model_manager: ModelManager):
111111
if "beautiful_prompt" in model_manager.model:
112112
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
113113

114-
def process_prompt(self, prompt, positive=True):
114+
def add_textual_inversion_tokens(self, prompt):
    """Expand every registered textual-inversion keyword in *prompt*.

    Each trigger word found in the prompt is replaced by its expanded
    placeholder-token string from ``self.keyword_dict``.
    """
    for trigger, expansion in self.keyword_dict.items():
        if trigger in prompt:
            prompt = prompt.replace(trigger, expansion)
    return prompt
119+
120+
def del_textual_inversion_tokens(self, prompt):
    """Return *prompt* with every registered textual-inversion keyword removed."""
    cleaned = prompt
    for trigger in self.keyword_dict:
        if trigger in cleaned:
            cleaned = cleaned.replace(trigger, "")
    return cleaned
125+
126+
def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
    """Prepare a prompt for encoding.

    Expands textual-inversion keywords into their placeholder tokens, then
    (for positive prompts only) applies the optional translator and
    BeautifulPrompt refiner.

    When ``require_pure_prompt`` is True, also returns the "pure" prompt —
    the input with all textual-inversion keywords stripped — for encoders
    that have no extra embedding rows (e.g. CLIP-G / T5).
    """
    pure_prompt = self.del_textual_inversion_tokens(prompt)
    prompt = self.add_textual_inversion_tokens(prompt)
    if positive and self.translator is not None:
        prompt = self.translator(prompt)
        print(f"Your prompt is translated: \"{prompt}\"")
    if positive and self.beautiful_prompt is not None:
        prompt = self.beautiful_prompt(prompt)
        print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
    return (prompt, pure_prompt) if require_pure_prompt else prompt
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from diffsynth import ModelManager, SD3ImagePipeline, download_models, load_state_dict
import torch


# Download models (automatically)
# `models/stable_diffusion_3/sd3_medium_incl_clips.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/sd3_medium_incl_clips.safetensors)
# `models/textual_inversion/verybadimagenegative_v1.3.pt`: [link](https://civitai.com/api/download/models/25820?type=Model&format=PickleTensor&size=full&fp=fp16)
download_models(["StableDiffusion3_without_T5", "TextualInversion_VeryBadImageNegative_v1.3"])

# Load textual inversions BEFORE the base model so the extra token
# embeddings can be appended to the text encoder's embedding table.
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_textual_inversions("models/textual_inversion")
model_manager.load_models(["models/stable_diffusion_3/sd3_medium_incl_clips.safetensors"])
pipe = SD3ImagePipeline.from_model_manager(model_manager)


# Generate a with/without pair for each seed so the effect of the
# negative textual-inversion embedding can be compared side by side.
for seed in range(4):
    for negative_prompt, tag in (("verybadimagenegative_v1.3", "with"), ("", "without")):
        torch.manual_seed(seed)
        image = pipe(
            prompt="a girl, highly detailed, absurd res, perfect image",
            negative_prompt=negative_prompt,
            cfg_scale=4.5,
            num_inference_steps=50, width=1024, height=1024,
        )
        image.save(f"image_{tag}_textual_inversion_{seed}.jpg")

0 commit comments

Comments
 (0)