Commit 7dbd4bc

Simplify for now
1 parent 9dd4ee9 commit 7dbd4bc

4 files changed: +15 additions, -206 deletions

examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 4 additions & 14 deletions
@@ -52,14 +52,11 @@ def retrieve_latents(


 class AuxiliaryLatentModule(nn.Module):
-    def __init__(self, font_path, dims=2, glyph_channels=256, position_channels=64, model_channels=256, **kwargs):
+    def __init__(self, dims, glyph_channels, position_channels, model_channels, **kwargs):
         super().__init__()
-        if font_path is None:
-            raise ValueError("font_path must be provided!")
-        self.font = ImageFont.truetype(font_path, 60)
+        self.font = ImageFont.truetype("./font/Arial_Unicode.ttf", 60)
         self.use_fp16 = kwargs.get("use_fp16", False)
         self.device = kwargs.get("device", "cpu")
-        self.scale_factor = 0.18215
         self.glyph_block = nn.Sequential(
             conv_nd(dims, glyph_channels, 8, 3, padding=1),
             nn.SiLU(),
@@ -98,15 +95,8 @@ def __init__(self, font_path, dims=2, glyph_channels=256, position_channels=64,
             nn.SiLU(),
         )

-        self.vae = AutoencoderKL.from_pretrained(
-            "runwayml/stable-diffusion-v1-5",
-            subfolder="vae",
-            torch_dtype=torch.float16 if self.use_fp16 else torch.float32,
-            variant="fp16" if self.use_fp16 else "fp32",
-        )
+        self.vae = kwargs.get("vae")
         self.vae.eval()
-        for param in self.vae.parameters():
-            param.requires_grad = False

         self.fuse_block = zero_module(conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1))

@@ -257,7 +247,7 @@ def forward(
         return guided_hint, hint, info

     def encode_first_stage(self, masked_img):
-        return retrieve_latents(self.vae.encode(masked_img)) * self.scale_factor
+        return retrieve_latents(self.vae.encode(masked_img)) * self.vae.scale_factor

     def arr2tensor(self, arr, bs):
         arr = np.transpose(arr, (2, 0, 1))

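With this change the module no longer builds its own frozen VAE or accepts a font_path: the VAE arrives through **kwargs and the font path is hardcoded. A minimal construction sketch under those assumptions (the model id and channel values simply mirror the defaults removed above, and the import assumes the example directory is on the path):

# Illustrative sketch only: construct the simplified module with an externally
# supplied VAE. Requires ./font/Arial_Unicode.ttf to exist, since __init__ now
# loads that path unconditionally.
import torch
from diffusers import AutoencoderKL

from auxiliary_latent_module import AuxiliaryLatentModule  # local example module

vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae", torch_dtype=torch.float32
)

aux = AuxiliaryLatentModule(
    dims=2,                  # the values here mirror the old defaults removed in this diff
    glyph_channels=256,
    position_channels=64,
    model_channels=256,
    vae=vae,                 # read via kwargs.get("vae")
    use_fp16=False,          # kwargs.get("use_fp16", False)
    device="cpu",            # kwargs.get("device", "cpu")
)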
examples/research_projects/anytext/embedding_manager.py

Lines changed: 0 additions & 4 deletions
@@ -156,10 +156,6 @@ def encode_text(self, text_info):
         if self.emb_type == "ocr":
             recog_emb = self.get_recog_emb(gline_list)
             enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
-        elif self.emb_type == "vit":
-            enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
-        elif self.emb_type == "conv":
-            enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
         if self.add_pos:
             enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
             enc_glyph = enc_glyph + enc_pos

examples/research_projects/anytext/pipeline_anytext.py

Lines changed: 2 additions & 9 deletions
@@ -218,11 +218,10 @@ def __init__(
         feature_extractor: CLIPImageProcessor,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
-        font_path: str = None,
     ):
         super().__init__()
-        self.text_embedding_module = TextEmbeddingModule(text_encoder, tokenizer)
-        self.auxiliary_latent_module = AuxiliaryLatentModule(font_path)
+        self.text_embedding_module = TextEmbeddingModule(use_fp16=unet.dtype == torch.float16)
+        self.auxiliary_latent_module = AuxiliaryLatentModule(vae=vae, use_fp16=unet.dtype == torch.float16)

         if safety_checker is None and requires_safety_checker:
             logger.warning(
@@ -1228,16 +1227,10 @@ def __call__(
         )
         prompt_embeds, negative_prompt_embeds = self.text_embedding_module(
             prompt,
-            device,
-            num_images_per_prompt,
-            self.do_classifier_free_guidance,
-            hint,
             text_info,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-            clip_skip=self.clip_skip,
         )
         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(

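Seen from the pipeline, the text embedding call now carries only the prompt, the AnyText layout info, and the negative prompt; device placement, per-prompt duplication, classifier-free guidance, LoRA scale, and clip-skip are no longer routed through this module. A rough sketch of the reduced call, with placeholders for everything this diff does not define:

# Placeholders: `pipe` is an assembled AnyTextPipeline and `text_info` is the
# glyph/position dictionary the pipeline builds earlier in __call__; neither is
# shown in this commit.
prompt = "a storefront sign that reads OPEN"
negative_prompt = "low quality, blurry"
text_info = ...  # built by the auxiliary latent path, structure not shown here

prompt_embeds, negative_prompt_embeds = pipe.text_embedding_module(
    prompt,
    text_info,
    negative_prompt,
)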
examples/research_projects/anytext/text_embedding_module.py

Lines changed: 9 additions & 179 deletions
@@ -31,11 +31,11 @@


 class TextEmbeddingModule(nn.Module):
-    def __init__(self, font_path, device, use_fp16):
+    def __init__(self, use_fp16):
         super().__init__()
-        self.device = device
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
         # TODO: Learn if the recommended font file is free to use
-        self.font = ImageFont.truetype(font_path, 60)
+        self.font = ImageFont.truetype("./font/Arial_Unicode.ttf", 60)
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device)
         self.embedding_manager_config = {
             "valid": True,
@@ -49,12 +49,12 @@ def __init__(self, font_path, device, use_fp16):
         # TODO: Understand the reason of param.requires_grad = True
         for param in self.embedding_manager.embedding_parameters():
             param.requires_grad = True
-        rec_model_dir = "./ocr_weights/ppv3_rec.pth"
+        rec_model_dir = "./ocr/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir).eval()
         args = {}
         args["rec_image_shape"] = "3, 48, 320"
         args["rec_batch_num"] = 6
-        args["rec_char_dict_path"] = "./ocr_recog/ppocr_keys_v1.txt"
+        args["rec_char_dict_path"] = "./ocr/ppocr_keys_v1.txt"
         args["use_fp16"] = use_fp16
         self.cn_recognizer = TextRecognizer(args, self.text_predictor)
         for param in self.text_predictor.parameters():
@@ -65,185 +65,15 @@ def __init__(self, font_path, device, use_fp16):

     def forward(
         self,
         prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        hint,
         text_info,
         negative_prompt=None,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
-        lora_scale: Optional[float] = None,
-        clip_skip: Optional[int] = None,
     ):
-        # TODO: Convert `get_learned_conditioning` functions to `diffusers`' format
-        prompt_embeds = self.get_learned_conditioning(
-            {"c_concat": [hint], "c_crossattn": [[prompt] * num_images_per_prompt], "text_info": text_info}
-        )
-        negative_prompt_embeds = self.get_learned_conditioning(
-            {"c_concat": [hint], "c_crossattn": [[negative_prompt] * num_images_per_prompt], "text_info": text_info}
-        )
+        self.embedding_manager.encode_text(text_info)
+        prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)

-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
-            self._lora_scale = lora_scale
-
-            # dynamically adjust the LoRA scale
-            if not USE_PEFT_BACKEND:
-                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            else:
-                scale_lora_layers(self.text_encoder, lora_scale)
-
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            # textual inversion: process multi-vector tokens if necessary
-            if isinstance(self, TextualInversionLoaderMixin):
-                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                text_input_ids, untruncated_ids
-            ):
-                removed_text = self.tokenizer.batch_decode(
-                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
-                )
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-                )
-
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                attention_mask = text_inputs.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            if clip_skip is None:
-                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
-                prompt_embeds = prompt_embeds[0]
-            else:
-                prompt_embeds = self.text_encoder(
-                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
-                )
-                # Access the `hidden_states` first, that contains a tuple of
-                # all the hidden states from the encoder layers. Then index into
-                # the tuple to access the hidden states from the desired layer.
-                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
-                # We also need to apply the final LayerNorm here to not mess with the
-                # representations. The `last_hidden_states` that we typically use for
-                # obtaining the final prompt representations passes through the LayerNorm
-                # layer.
-                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
-
-        if self.text_encoder is not None:
-            prompt_embeds_dtype = self.text_encoder.dtype
-        elif self.unet is not None:
-            prompt_embeds_dtype = self.unet.dtype
-        else:
-            prompt_embeds_dtype = prompt_embeds.dtype
-
-        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens: List[str]
-            if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
-            elif prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
-
-            # textual inversion: process multi-vector tokens if necessary
-            if isinstance(self, TextualInversionLoaderMixin):
-                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
-
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                attention_mask = uncond_input.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            negative_prompt_embeds = negative_prompt_embeds[0]
-
-        if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = negative_prompt_embeds.shape[1]
-
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-        if self.text_encoder is not None:
-            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
-                unscale_lora_layers(self.text_encoder, lora_scale)
+        self.embedding_manager.encode_text(text_info)
+        negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode([negative_prompt], embedding_manager=self.embedding_manager)

         return prompt_embeds, negative_prompt_embeds
-
-    def get_learned_conditioning(self, c):
-        if hasattr(self.frozen_CLIP_embedder_t3, "encode") and callable(self.frozen_CLIP_embedder_t3.encode):
-            if self.embedding_manager is not None and c["text_info"] is not None:
-                self.embedding_manager.encode_text(c["text_info"])
-            if isinstance(c, dict):
-                cond_txt = c["c_crossattn"][0]
-            else:
-                cond_txt = c
-            if self.embedding_manager is not None:
-                cond_txt = self.frozen_CLIP_embedder_t3.encode(cond_txt, embedding_manager=self.embedding_manager)
-            else:
-                cond_txt = self.frozen_CLIP_embedder_t3.encode(cond_txt)
-            if isinstance(c, dict):
-                c["c_crossattn"][0] = cond_txt
-            else:
-                c = cond_txt
-            if isinstance(c, DiagonalGaussianDistribution):
-                c = c.mode()
-        else:
-            c = self.frozen_CLIP_embedder_t3(c)
-
-        return c

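After this commit both helper modules resolve their assets from fixed relative paths rather than constructor arguments, so those files must exist relative to the working directory. A small pre-flight check (illustrative only; the paths are the hardcoded strings from the diffs above):

# Verify the relative assets the simplified modules now expect to find.
from pathlib import Path

REQUIRED_ASSETS = [
    "font/Arial_Unicode.ttf",  # loaded via ImageFont.truetype in both modules
    "ocr/ppv3_rec.pth",        # rec_model_dir in TextEmbeddingModule
    "ocr/ppocr_keys_v1.txt",   # rec_char_dict_path in TextEmbeddingModule
]

missing = [path for path in REQUIRED_ASSETS if not Path(path).exists()]
if missing:
    raise FileNotFoundError(f"Missing AnyText assets relative to the working directory: {missing}")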