
Commit 9dd4ee9

Converting TextEmbeddingModule to an ordinary encode_prompt() function
1 parent 187473d commit 9dd4ee9
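
In outline, the commit widens `TextEmbeddingModule.forward` to match the `encode_prompt()` signature used elsewhere in `diffusers`. A condensed before/after of the signature, taken from the diffs below (bodies elided):

```python
# Before: the negative prompt was a required positional argument named n_prompt.
def forward(self, prompt, device, num_images_per_prompt,
            do_classifier_free_guidance, hint, n_prompt, text_info): ...

# After: encode_prompt()-style keywords with defaults; note that text_info
# now precedes negative_prompt.
def forward(self, prompt, device, num_images_per_prompt,
            do_classifier_free_guidance, hint, text_info,
            negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None,
            lora_scale=None, clip_skip=None): ...
```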

File tree

2 files changed: +178 −12 lines


examples/research_projects/anytext/pipeline_anytext.py

Lines changed: 9 additions & 5 deletions

```diff
@@ -1145,10 +1145,6 @@ def __call__(
         )
         guess_mode = guess_mode or global_pool_conditions
 
-        # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
-        )
         prompt, texts = self.modify_prompt(prompt)
 
         # For classifier free guidance, we need to do two forward passes.
@@ -1226,14 +1222,22 @@ def __call__(
         else:
             assert False
 
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
         prompt_embeds, negative_prompt_embeds = self.text_embedding_module(
             prompt,
             device,
             num_images_per_prompt,
             self.do_classifier_free_guidance,
             hint,
-            negative_prompt,
             text_info,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
         )
         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
```
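
Because `negative_prompt` moved behind `text_info` in the argument list, a caller still passing it positionally in the old slot would silently bind it to `text_info`. A keyword-style call sidesteps the pitfall; the sketch below is illustrative only, with `pipe`, `hint`, and `text_info` assumed to come from earlier pipeline steps:

```python
# Hedged sketch of the post-commit call, not code from the diff itself.
prompt_embeds, negative_prompt_embeds = pipe.text_embedding_module(
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance=True,
    hint=hint,
    text_info=text_info,
    negative_prompt=negative_prompt,  # keyword keeps it out of the old positional slot
    lora_scale=text_encoder_lora_scale,
    clip_skip=None,
)
```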

examples/research_projects/anytext/text_embedding_module.py

Lines changed: 169 additions & 7 deletions

```diff
@@ -3,12 +3,24 @@
 # text -> tokenizer ->
 
 
+from typing import List, Optional
+
 import torch
 from PIL import ImageFont
 from torch import nn
 
+from diffusers.loaders import (
+    StableDiffusionLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
-from diffusers.utils import logging
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 
 from .embedding_manager import EmbeddingManager
 from .frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
```
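
The new imports support the standard `diffusers` pattern of scaling the text encoder's LoRA layers for the duration of one encode call and restoring them afterwards. A minimal sketch of that bracket, assuming a CLIP-style `text_encoder`; the `encode_with_lora_scale` helper is illustrative, not part of the commit:

```python
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers


def encode_with_lora_scale(text_encoder, input_ids, lora_scale):
    # Scale the LoRA layers up front...
    if USE_PEFT_BACKEND:
        scale_lora_layers(text_encoder, lora_scale)
    else:
        adjust_lora_scale_text_encoder(text_encoder, lora_scale)
    try:
        return text_encoder(input_ids)[0]
    finally:
        # ...and scale them back so the encoder is unchanged for later calls
        # (the PEFT backend is the only path diffusers unscales).
        if USE_PEFT_BACKEND:
            unscale_lora_layers(text_encoder, lora_scale)
```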
```diff
@@ -50,14 +62,167 @@ def __init__(self, font_path, device, use_fp16):
         self.embedding_manager.recog = self.cn_recognizer
 
     @torch.no_grad()
-    def forward(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, hint, n_prompt, text_info):
+    def forward(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        hint,
+        text_info,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        # TODO: Convert `get_learned_conditioning` functions to `diffusers`' format
         prompt_embeds = self.get_learned_conditioning(
-            {"c_concat": [hint], "c_crossattn": [[prompt] * len(prompt)], "text_info": text_info}
+            {"c_concat": [hint], "c_crossattn": [[prompt] * num_images_per_prompt], "text_info": text_info}
         )
         negative_prompt_embeds = self.get_learned_conditioning(
-            {"c_concat": [hint], "c_crossattn": [[n_prompt] * len(prompt)], "text_info": text_info}
+            {"c_concat": [hint], "c_crossattn": [[negative_prompt] * num_images_per_prompt], "text_info": text_info}
         )
 
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+            max_length = prompt_embeds.shape[1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            negative_prompt_embeds = self.text_encoder(
+                uncond_input.input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            negative_prompt_embeds = negative_prompt_embeds[0]
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        if self.text_encoder is not None:
+            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
         return prompt_embeds, negative_prompt_embeds
 
     def get_learned_conditioning(self, c):
```
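
The `clip_skip` branch in the new `forward` takes an earlier encoder hidden state instead of the last one and re-applies the final LayerNorm by hand. A self-contained sketch of just that step; the checkpoint name and prompt are illustrative only, since AnyText ships its own text encoder:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

ids = tokenizer(
    "a street sign", padding="max_length",
    max_length=tokenizer.model_max_length, return_tensors="pt",
).input_ids

clip_skip = 2
with torch.no_grad():
    out = text_encoder(ids, output_hidden_states=True)
    # hidden_states[-1] is the final layer; -(clip_skip + 1) walks back clip_skip layers.
    embeds = out.hidden_states[-(clip_skip + 1)]
    # Re-apply the final LayerNorm that last_hidden_state would normally pass through.
    embeds = text_encoder.text_model.final_layer_norm(embeds)
```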
```diff
@@ -82,6 +247,3 @@ def get_learned_conditioning(self, c):
         c = self.frozen_CLIP_embedder_t3(c)
 
         return c
-
-    def get_unconditional_conditioning(self, N):
-        return self.get_learned_conditioning({"c_crossattn": [[""] * N], "text_info": None})
```
