
Commit 52fb0b4

make style
1 parent cc0c6e5 · commit 52fb0b4

File tree

examples/research_projects/anytext/auxiliary_latent_module.py
examples/research_projects/anytext/text_embedding_module.py

2 files changed: +40 −35 lines


examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 23 additions & 16 deletions
@@ -196,10 +196,10 @@ def forward(
         np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
         # prepare info dict
         info = {}
-        info['glyphs'] = []
-        info['gly_line'] = []
-        info['positions'] = []
-        info['n_lines'] = [len(texts)]*len(prompt)
+        info["glyphs"] = []
+        info["gly_line"] = []
+        info["positions"] = []
+        info["n_lines"] = [len(texts)] * len(prompt)
         for i in range(len(texts)):
             text = texts[i]
             if len(text) > max_chars:
@@ -209,40 +209,47 @@ def forward(
             gly_scale = 2
             if pre_pos[i].mean() != 0:
                 gly_line = self.draw_glyph(self.font, text)
-                glyphs = self.draw_glyph2(self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False)
+                glyphs = self.draw_glyph2(
+                    self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
+                )
                 if revise_pos:
                     resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))
-                    new_pos = cv2.morphologyEx((resize_gly*255).astype(np.uint8), cv2.MORPH_CLOSE, kernel=np.ones((resize_gly.shape[0]//10, resize_gly.shape[1]//10), dtype=np.uint8), iterations=1)
+                    new_pos = cv2.morphologyEx(
+                        (resize_gly * 255).astype(np.uint8),
+                        cv2.MORPH_CLOSE,
+                        kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8),
+                        iterations=1,
+                    )
                     new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
                     contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
                     if len(contours) != 1:
-                        str_warning = f'Fail to revise position {i} to bounding rect, remain position unchanged...'
+                        str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
                         logger.warning(str_warning)
                     else:
                         rect = cv2.minAreaRect(contours[0])
                         poly = np.int0(cv2.boxPoints(rect))
-                        pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.
+                        pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
             else:
-                glyphs = np.zeros((h*gly_scale, w*gly_scale, 1))
+                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
                 gly_line = np.zeros((80, 512, 1))
             pos = pre_pos[i]
-            info['glyphs'] += [self.arr2tensor(glyphs, len(prompt))]
-            info['gly_line'] += [self.arr2tensor(gly_line, len(prompt))]
-            info['positions'] += [self.arr2tensor(pos, len(prompt))]
+            info["glyphs"] += [self.arr2tensor(glyphs, len(prompt))]
+            info["gly_line"] += [self.arr2tensor(gly_line, len(prompt))]
+            info["positions"] += [self.arr2tensor(pos, len(prompt))]
         # get masked_x
-        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0)*(1-np_hint)
+        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
         masked_img = np.transpose(masked_img, (2, 0, 1))
         masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
         if self.use_fp16:
             masked_img = masked_img.half()
         masked_x = self.encode_first_stage(masked_img[None, ...]).detach()
         if self.use_fp16:
             masked_x = masked_x.half()
-        info['masked_x'] = torch.cat([masked_x for _ in range(len(prompt))], dim=0)
+        info["masked_x"] = torch.cat([masked_x for _ in range(len(prompt))], dim=0)
         hint = self.arr2tensor(np_hint, len(prompt))
 
-        glyphs = torch.cat(info['glyphs'], dim=1).sum(dim=1, keepdim=True)
-        positions = torch.cat(info['positions'], dim=1).sum(dim=1, keepdim=True)
+        glyphs = torch.cat(info["glyphs"], dim=1).sum(dim=1, keepdim=True)
+        positions = torch.cat(info["positions"], dim=1).sum(dim=1, keepdim=True)
         enc_glyph = self.glyph_block(glyphs, emb, context)
         enc_pos = self.position_block(positions, emb, context)
         guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, masked_x], dim=1))
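
Aside from quote and spacing normalization, the re-wrapped morphologyEx call is the densest logic this hunk touches: it closes small gaps in the rendered glyph mask, then replaces the mask with its filled minimum-area rectangle. A minimal, self-contained sketch of that revision step on a synthetic mask, assuming only OpenCV and NumPy (the //10 kernel heuristic and shape handling follow the diff; the input mask and variable values are illustrative):

import cv2
import numpy as np

# Synthetic stand-in for the resized glyph rendering (values in [0, 1]).
resize_gly = np.zeros((80, 120), dtype=np.float32)
resize_gly[20:60, 30:90] = 1.0

# Morphological close merges nearby glyph pixels into one blob.
new_pos = cv2.morphologyEx(
    (resize_gly * 255).astype(np.uint8),
    cv2.MORPH_CLOSE,
    kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8),
    iterations=1,
)
new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos

contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
if len(contours) == 1:
    # Replace the mask with the filled minimum-area bounding rectangle.
    rect = cv2.minAreaRect(contours[0])
    poly = cv2.boxPoints(rect).astype(np.intp)  # the diff uses np.int0, an alias dropped in NumPy 2.0
    revised = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0

Requiring exactly one contour before revising leaves multi-blob renderings untouched, which matches the warning branch in the diff ("remain position unchanged").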

examples/research_projects/anytext/text_embedding_module.py

Lines changed: 17 additions & 19 deletions
@@ -2,19 +2,13 @@
 # +> Token Replacement -> FrozenCLIPEmbedderT3
 # text -> tokenizer ->
 
-from typing import List, Optional
 
-import cv2
-import numpy as np
 import torch
-from easydict import EasyDict as edict
-from PIL import Image, ImageDraw, ImageFont
+from PIL import ImageFont
 from torch import nn
 
-from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
-from diffusers.models.lora import adjust_lora_scale_text_encoder
-from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.utils import logging
 
 from .embedding_manager import EmbeddingManager
 from .frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
@@ -46,36 +40,40 @@ def __init__(self, font_path, device, use_fp16):
         rec_model_dir = "./ocr_weights/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir).eval()
         args = {}
-        args['rec_image_shape'] = "3, 48, 320"
-        args['rec_batch_num'] = 6
-        args['rec_char_dict_path'] = "./ocr_recog/ppocr_keys_v1.txt"
-        args['use_fp16'] = use_fp16
+        args["rec_image_shape"] = "3, 48, 320"
+        args["rec_batch_num"] = 6
+        args["rec_char_dict_path"] = "./ocr_recog/ppocr_keys_v1.txt"
+        args["use_fp16"] = use_fp16
         self.cn_recognizer = TextRecognizer(args, self.text_predictor)
         for param in self.text_predictor.parameters():
             param.requires_grad = False
         self.embedding_manager.recog = self.cn_recognizer
 
     @torch.no_grad()
     def forward(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, hint, n_prompt, text_info):
-        prompt_embeds = self.get_learned_conditioning({"c_concat": [hint], "c_crossattn": [[prompt] * len(prompt)], "text_info": text_info})
-        negative_prompt_embeds = self.get_learned_conditioning({"c_concat": [hint], "c_crossattn": [[n_prompt] * len(prompt)], "text_info": text_info})
+        prompt_embeds = self.get_learned_conditioning(
+            {"c_concat": [hint], "c_crossattn": [[prompt] * len(prompt)], "text_info": text_info}
+        )
+        negative_prompt_embeds = self.get_learned_conditioning(
+            {"c_concat": [hint], "c_crossattn": [[n_prompt] * len(prompt)], "text_info": text_info}
+        )
 
         return prompt_embeds, negative_prompt_embeds
 
     def get_learned_conditioning(self, c):
-        if hasattr(self.frozen_CLIP_embedder_t3, 'encode') and callable(self.frozen_CLIP_embedder_t3.encode):
-            if self.embedding_manager is not None and c['text_info'] is not None:
-                self.embedding_manager.encode_text(c['text_info'])
+        if hasattr(self.frozen_CLIP_embedder_t3, "encode") and callable(self.frozen_CLIP_embedder_t3.encode):
+            if self.embedding_manager is not None and c["text_info"] is not None:
+                self.embedding_manager.encode_text(c["text_info"])
             if isinstance(c, dict):
-                cond_txt = c['c_crossattn'][0]
+                cond_txt = c["c_crossattn"][0]
             else:
                 cond_txt = c
             if self.embedding_manager is not None:
                 cond_txt = self.frozen_CLIP_embedder_t3.encode(cond_txt, embedding_manager=self.embedding_manager)
             else:
                 cond_txt = self.frozen_CLIP_embedder_t3.encode(cond_txt)
             if isinstance(c, dict):
-                c['c_crossattn'][0] = cond_txt
+                c["c_crossattn"][0] = cond_txt
             else:
                 c = cond_txt
         if isinstance(c, DiagonalGaussianDistribution):
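
The re-wrapped get_learned_conditioning calls pass a dict with c_concat, c_crossattn, and text_info keys, while the method itself also accepts a bare prompt batch. A minimal sketch of that dispatch, using hypothetical DummyEmbedder / DummyManager stand-ins for FrozenCLIPEmbedderT3 and EmbeddingManager; only the control flow mirrors the diffed code, the shapes are illustrative:

import torch


class DummyEmbedder:
    # Hypothetical stand-in for FrozenCLIPEmbedderT3: one 77x768 embedding per prompt.
    def encode(self, texts, embedding_manager=None):
        return torch.zeros(len(texts), 77, 768)


class DummyManager:
    # Hypothetical stand-in for EmbeddingManager.
    def encode_text(self, text_info):
        pass  # the real manager injects glyph/OCR features here


def get_learned_conditioning(embedder, manager, c):
    if hasattr(embedder, "encode") and callable(embedder.encode):
        if manager is not None and c["text_info"] is not None:
            manager.encode_text(c["text_info"])
        # Dict inputs carry the prompt batch under c_crossattn; otherwise c is the batch.
        cond_txt = c["c_crossattn"][0] if isinstance(c, dict) else c
        cond_txt = embedder.encode(cond_txt, embedding_manager=manager)
        if isinstance(c, dict):
            c["c_crossattn"][0] = cond_txt
        else:
            c = cond_txt
    return c


c = {"c_concat": [None], "c_crossattn": [["a street sign"]], "text_info": None}
out = get_learned_conditioning(DummyEmbedder(), DummyManager(), c)
print(out["c_crossattn"][0].shape)  # torch.Size([1, 77, 768])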
