
Commit da67ff7

Fix: Move glyph rendering to TextEmbeddingModule from AuxiliaryLatentModule
1 parent 1cdbb55 commit da67ff7

File tree: 3 files changed, +161 -162 lines changed
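The refactor moves the glyph/position preprocessing into TextEmbeddingModule, which now hands a prepared text_info dict to AuxiliaryLatentModule instead of the auxiliary module building it itself. A rough sketch of that dict, with keys taken from the diffs below (the shapes in the comments are assumptions inferred from the defaults h=w=512 and gly_scale=2 and from arr2tensor batching, not something stated in the commit):

    # Sketch of the text_info dict that now crosses the module boundary.
    # Keys come from the diff below; the example shapes are assumptions.
    text_info = {
        "glyphs": [],      # one tensor per text line, e.g. (batch, 1, 1024, 1024) glyph renderings
        "gly_line": [],    # one tensor per text line, e.g. (batch, 1, 80, 512) single-line glyph strips
        "positions": [],   # one tensor per text line, e.g. (batch, 1, 512, 512) position masks
        "n_lines": [],     # number of text lines, repeated once per generated image
        "masked_x": None,  # VAE latent of the masked reference image, repeated over the batch
    }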


examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 6 additions & 137 deletions
@@ -53,7 +53,7 @@ def retrieve_latents(
 class AuxiliaryLatentModule(nn.Module):
     def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels=320, **kwargs):
         super().__init__()
-        self.font = ImageFont.truetype("/home/x/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60)
+        self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60)
         self.use_fp16 = kwargs.get("use_fp16", False)
         self.device = kwargs.get("device", "cpu")
         self.glyph_block = nn.Sequential(
@@ -104,146 +104,15 @@ def forward(
         self,
         emb,
         context,
-        mode,
-        texts,
-        prompt,
-        draw_pos,
-        ori_image,
-        max_chars=77,
-        revise_pos=False,
-        sort_priority=False,
-        h=512,
-        w=512,
+        text_info,
     ):
-        if prompt is None and texts is None:
-            raise ValueError("Prompt or texts must be provided!")
-        n_lines = len(texts)
-        if mode == "generate":
-            edit_image = np.ones((h, w, 3)) * 127.5  # empty mask image
-        elif mode == "edit":
-            if draw_pos is None or ori_image is None:
-                raise ValueError("Reference image and position image are needed for text editing!")
-            if isinstance(ori_image, str):
-                ori_image = cv2.imread(ori_image)[..., ::-1]
-                if ori_image is None:
-                    raise ValueError(f"Can't read ori_image image from {ori_image}!")
-            elif isinstance(ori_image, torch.Tensor):
-                ori_image = ori_image.cpu().numpy()
-            else:
-                if not isinstance(ori_image, np.ndarray):
-                    raise ValueError(f"Unknown format of ori_image: {type(ori_image)}")
-            edit_image = ori_image.clip(1, 255)  # for mask reason
-            edit_image = self.check_channels(edit_image)
-            edit_image = self.resize_image(
-                edit_image, max_length=768
-            )  # make w h multiple of 64, resize if w or h > max_length
-            h, w = edit_image.shape[:2]  # change h, w by input ref_img
-        # preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
-        if draw_pos is None:
-            pos_imgs = np.zeros((w, h, 1))
-        if isinstance(draw_pos, str):
-            draw_pos = cv2.imread(draw_pos)[..., ::-1]
-            if draw_pos is None:
-                raise ValueError(f"Can't read draw_pos image from {draw_pos}!")
-            pos_imgs = 255 - draw_pos
-        elif isinstance(draw_pos, torch.Tensor):
-            pos_imgs = draw_pos.cpu().numpy()
-        else:
-            if not isinstance(draw_pos, np.ndarray):
-                raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}")
-        if mode == "edit":
-            pos_imgs = cv2.resize(pos_imgs, (w, h))
-        pos_imgs = pos_imgs[..., 0:1]
-        pos_imgs = cv2.convertScaleAbs(pos_imgs)
-        _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
-        # separate pos_imgs
-        pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
-        if len(pos_imgs) == 0:
-            pos_imgs = [np.zeros((h, w, 1))]
-        if len(pos_imgs) < n_lines:
-            if n_lines == 1 and texts[0] == " ":
-                pass  # text-to-image without text
-            else:
-                raise ValueError(
-                    f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!"
-                )
-        elif len(pos_imgs) > n_lines:
-            str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
-            logger.warning(str_warning)
-        # get pre_pos, poly_list, hint that needed for anytext
-        pre_pos = []
-        poly_list = []
-        for input_pos in pos_imgs:
-            if input_pos.mean() != 0:
-                input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos
-                poly, pos_img = self.find_polygon(input_pos)
-                pre_pos += [pos_img / 255.0]
-                poly_list += [poly]
-            else:
-                pre_pos += [np.zeros((h, w, 1))]
-                poly_list += [None]
-        np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
-        # prepare info dict
-        info = {}
-        info["glyphs"] = []
-        info["gly_line"] = []
-        info["positions"] = []
-        info["n_lines"] = [len(texts)] * len(prompt)
-        for i in range(len(texts)):
-            text = texts[i]
-            if len(text) > max_chars:
-                str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...'
-                logger.warning(str_warning)
-                text = text[:max_chars]
-            gly_scale = 2
-            if pre_pos[i].mean() != 0:
-                gly_line = self.draw_glyph(self.font, text)
-                glyphs = self.draw_glyph2(
-                    self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
-                )
-                if revise_pos:
-                    resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))
-                    new_pos = cv2.morphologyEx(
-                        (resize_gly * 255).astype(np.uint8),
-                        cv2.MORPH_CLOSE,
-                        kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8),
-                        iterations=1,
-                    )
-                    new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
-                    contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
-                    if len(contours) != 1:
-                        str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
-                        logger.warning(str_warning)
-                    else:
-                        rect = cv2.minAreaRect(contours[0])
-                        poly = np.int0(cv2.boxPoints(rect))
-                        pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
-            else:
-                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
-                gly_line = np.zeros((80, 512, 1))
-            pos = pre_pos[i]
-            info["glyphs"] += [self.arr2tensor(glyphs, len(prompt))]
-            info["gly_line"] += [self.arr2tensor(gly_line, len(prompt))]
-            info["positions"] += [self.arr2tensor(pos, len(prompt))]
-        # get masked_x
-        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
-        masked_img = np.transpose(masked_img, (2, 0, 1))
-        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
-        if self.use_fp16:
-            masked_img = masked_img.half()
-        masked_x = self.encode_first_stage(masked_img[None, ...]).detach()
-        if self.use_fp16:
-            masked_x = masked_x.half()
-        info["masked_x"] = torch.cat([masked_x for _ in range(len(prompt))], dim=0)
-        hint = self.arr2tensor(np_hint, len(prompt))
-
-        glyphs = torch.cat(info["glyphs"], dim=1).sum(dim=1, keepdim=True)
-        positions = torch.cat(info["positions"], dim=1).sum(dim=1, keepdim=True)
+        glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True)
+        positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True)
         enc_glyph = self.glyph_block(glyphs, emb, context)
         enc_pos = self.position_block(positions, emb, context)
-        guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, masked_x], dim=1))
+        guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1))

-        return guided_hint, hint, info
+        return guided_hint

     def encode_first_stage(self, masked_img):
         return retrieve_latents(self.vae.encode(masked_img)) * self.vae.scale_factor
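With the preprocessing removed, AuxiliaryLatentModule.forward only fuses features computed upstream. A minimal usage sketch of the new signature (the three inputs are assumed to already exist in the pipeline at this point; variable names mirror the pipeline diff below):

    # Hypothetical call site mirroring the simplified forward signature.
    guided_hint = auxiliary_latent_module(
        emb=timestep_cond,      # timestep embedding
        context=prompt_embeds,  # encoded prompt
        text_info=text_info,    # dict built by TextEmbeddingModule.forward
    )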

examples/research_projects/anytext/pipeline_anytext.py

Lines changed: 16 additions & 18 deletions
@@ -1114,6 +1114,20 @@ def __call__(

         prompt, texts = self.modify_prompt(prompt)

+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+        prompt_embeds, negative_prompt_embeds, text_info = self.text_embedding_module(
+            prompt,
+            texts,
+            negative_prompt,
+            num_images_per_prompt,
+            mode,
+            draw_pos,
+            ori_image,
+        )
+
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
@@ -1151,15 +1165,10 @@ def __call__(
             # guess_mode=guess_mode,
             # )
             # height, width = image.shape[-2:]
-            guided_hint, hint, text_info = self.auxiliary_latent_module(
+            guided_hint = self.auxiliary_latent_module(
                 emb=timestep_cond,
                 context=prompt_embeds,
-                mode=mode,
-                texts=texts,
-                prompt=prompt,
-                draw_pos=draw_pos,
-                ori_image=ori_image,
-                img_count=len(prompt),
+                text_info=text_info,
             )
         # elif isinstance(controlnet, MultiControlNetModel):
         #     images = []
@@ -1189,17 +1198,6 @@ def __call__(
         else:
             assert False

-        # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
-        )
-        prompt_embeds, negative_prompt_embeds = self.text_embedding_module(
-            prompt,
-            text_info,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-        )
         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler, num_inference_steps, device, timesteps, sigmas
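Condensed view of the resulting order inside __call__ (a sketch assembled from the hunks above, with argument lists abbreviated; not the full pipeline code):

    prompt, texts = self.modify_prompt(prompt)

    # 3. Encode input prompt (now runs before the ControlNet branch and also returns text_info)
    prompt_embeds, negative_prompt_embeds, text_info = self.text_embedding_module(
        prompt, texts, negative_prompt, num_images_per_prompt, mode, draw_pos, ori_image
    )

    # ... later, inside the ControlNetModel branch, only the prepared dict is consumed
    guided_hint = self.auxiliary_latent_module(emb=timestep_cond, context=prompt_embeds, text_info=text_info)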

examples/research_projects/anytext/text_embedding_module.py

Lines changed: 139 additions & 7 deletions
@@ -11,6 +11,9 @@
 from PIL import ImageFont
 from recognizer import TextRecognizer, create_predictor
 from torch import nn
+from torch.nn import functional as F
+import numpy as np
+import cv2

 from diffusers.utils import (
     logging,
@@ -25,7 +28,7 @@ def __init__(self, use_fp16):
         super().__init__()
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         # TODO: Learn if the recommended font file is free to use
-        self.font = ImageFont.truetype("/home/x/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60)
+        self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60)
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device)
         self.embedding_manager_config = {
             "valid": True,
@@ -39,12 +42,12 @@ def __init__(self, use_fp16):
         # TODO: Understand the reason of param.requires_grad = True
         for param in self.embedding_manager.embedding_parameters():
             param.requires_grad = True
-        rec_model_dir = "/home/x/Documents/gits/AnyText/ocr_weights/ppv3_rec.pth"
+        rec_model_dir = "/home/cosmos/Documents/gits/AnyText/ocr_weights/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir).eval()
         args = {}
         args["rec_image_shape"] = "3, 48, 320"
         args["rec_batch_num"] = 6
-        args["rec_char_dict_path"] = "/home/x/Documents/gits/AnyText/ocr_weights/ppocr_keys_v1.txt"
+        args["rec_char_dict_path"] = "/home/cosmos/Documents/gits/AnyText/ocr_weights/ppocr_keys_v1.txt"
         args["use_fp16"] = use_fp16
         self.cn_recognizer = TextRecognizer(args, self.text_predictor)
         for param in self.text_predictor.parameters():
@@ -55,11 +58,140 @@ def __init__(self, use_fp16):
     def forward(
         self,
         prompt,
-        text_info,
-        negative_prompt=None,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        texts,
+        negative_prompt,
+        num_images_per_prompt,
+        mode,
+        draw_pos,
+        ori_image,
+        max_chars=77,
+        revise_pos=False,
+        sort_priority=False,
+        h=512,
+        w=512,
     ):
+        if prompt is None and texts is None:
+            raise ValueError("Prompt or texts must be provided!")
+        n_lines = len(texts)
+        if mode == "generate":
+            edit_image = np.ones((h, w, 3)) * 127.5  # empty mask image
+        elif mode == "edit":
+            if draw_pos is None or ori_image is None:
+                raise ValueError("Reference image and position image are needed for text editing!")
+            if isinstance(ori_image, str):
+                ori_image = cv2.imread(ori_image)[..., ::-1]
+                if ori_image is None:
+                    raise ValueError(f"Can't read ori_image image from {ori_image}!")
+            elif isinstance(ori_image, torch.Tensor):
+                ori_image = ori_image.cpu().numpy()
+            else:
+                if not isinstance(ori_image, np.ndarray):
+                    raise ValueError(f"Unknown format of ori_image: {type(ori_image)}")
+            edit_image = ori_image.clip(1, 255)  # for mask reason
+            edit_image = self.check_channels(edit_image)
+            edit_image = self.resize_image(
+                edit_image, max_length=768
+            )  # make w h multiple of 64, resize if w or h > max_length
+            h, w = edit_image.shape[:2]  # change h, w by input ref_img
+        # preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
+        if draw_pos is None:
+            pos_imgs = np.zeros((w, h, 1))
+        if isinstance(draw_pos, str):
+            draw_pos = cv2.imread(draw_pos)[..., ::-1]
+            if draw_pos is None:
+                raise ValueError(f"Can't read draw_pos image from {draw_pos}!")
+            pos_imgs = 255 - draw_pos
+        elif isinstance(draw_pos, torch.Tensor):
+            pos_imgs = draw_pos.cpu().numpy()
+        else:
+            if not isinstance(draw_pos, np.ndarray):
+                raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}")
+        if mode == "edit":
+            pos_imgs = cv2.resize(pos_imgs, (w, h))
+        pos_imgs = pos_imgs[..., 0:1]
+        pos_imgs = cv2.convertScaleAbs(pos_imgs)
+        _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
+        # separate pos_imgs
+        pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
+        if len(pos_imgs) == 0:
+            pos_imgs = [np.zeros((h, w, 1))]
+        if len(pos_imgs) < n_lines:
+            if n_lines == 1 and texts[0] == " ":
+                pass  # text-to-image without text
+            else:
+                raise ValueError(
+                    f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!"
+                )
+        elif len(pos_imgs) > n_lines:
+            str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
+            logger.warning(str_warning)
+        # get pre_pos, poly_list, hint that needed for anytext
+        pre_pos = []
+        poly_list = []
+        for input_pos in pos_imgs:
+            if input_pos.mean() != 0:
+                input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos
+                poly, pos_img = self.find_polygon(input_pos)
+                pre_pos += [pos_img / 255.0]
+                poly_list += [poly]
+            else:
+                pre_pos += [np.zeros((h, w, 1))]
+                poly_list += [None]
+        np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
+        # prepare info dict
+        text_info = {}
+        text_info["glyphs"] = []
+        text_info["gly_line"] = []
+        text_info["positions"] = []
+        text_info["n_lines"] = [len(texts)] * num_images_per_prompt
+        for i in range(len(texts)):
+            text = texts[i]
+            if len(text) > max_chars:
+                str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...'
+                logger.warning(str_warning)
+                text = text[:max_chars]
+            gly_scale = 2
+            if pre_pos[i].mean() != 0:
+                gly_line = self.draw_glyph(self.font, text)
+                glyphs = self.draw_glyph2(
+                    self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
+                )
+                if revise_pos:
+                    resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))
+                    new_pos = cv2.morphologyEx(
+                        (resize_gly * 255).astype(np.uint8),
+                        cv2.MORPH_CLOSE,
+                        kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8),
+                        iterations=1,
+                    )
+                    new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
+                    contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+                    if len(contours) != 1:
+                        str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
+                        logger.warning(str_warning)
+                    else:
+                        rect = cv2.minAreaRect(contours[0])
+                        poly = np.int0(cv2.boxPoints(rect))
+                        pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
+            else:
+                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
+                gly_line = np.zeros((80, 512, 1))
+            pos = pre_pos[i]
+            text_info["glyphs"] += [self.arr2tensor(glyphs, len(prompt))]
+            text_info["gly_line"] += [self.arr2tensor(gly_line, len(prompt))]
+            text_info["positions"] += [self.arr2tensor(pos, len(prompt))]
+        # get masked_x
+        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
+        masked_img = np.transpose(masked_img, (2, 0, 1))
+        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
+        if self.use_fp16:
+            masked_img = masked_img.half()
+        masked_x = self.encode_first_stage(masked_img[None, ...]).detach()
+        if self.use_fp16:
+            masked_x = masked_x.half()
+        text_info["masked_x"] = torch.cat([masked_x for _ in range(len(prompt))], dim=0)
+        # hint = self.arr2tensor(np_hint, len(prompt))
+
         self.embedding_manager.encode_text(text_info)
         prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)