Commit a8dbbe2

Up
1 parent af30f0f commit a8dbbe2

5 files changed: 251 additions, 158 deletions


examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 101 additions & 117 deletions
@@ -2,12 +2,14 @@
 # +> fuse layer
 # position l_p -> position block ->
 
+import math
 from typing import Optional
 
 import cv2
 import numpy as np
 import torch
-from PIL import Image, ImageDraw, ImageFont
+from einops import repeat
+from PIL import ImageFont
 from torch import nn
 
 from diffusers.utils import logging
@@ -16,19 +18,6 @@
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
 
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
 # Copied from diffusers.models.controlnet.zero_module
 def zero_module(module: nn.Module) -> nn.Module:
     for p in module.parameters():
@@ -56,74 +45,142 @@ def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels
         self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60)
         self.use_fp16 = kwargs.get("use_fp16", False)
         self.device = kwargs.get("device", "cpu")
+        self.model_channels = model_channels
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            nn.Linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            nn.Linear(time_embed_dim, time_embed_dim),
+        )
         self.glyph_block = nn.Sequential(
-            conv_nd(dims, glyph_channels, 8, 3, padding=1),
+            nn.Conv2d(glyph_channels, 8, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 8, 8, 3, padding=1),
+            nn.Conv2d(8, 8, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
+            nn.Conv2d(8, 16, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 16, 16, 3, padding=1),
+            nn.Conv2d(16, 16, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            nn.Conv2d(16, 32, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 32, 32, 3, padding=1),
+            nn.Conv2d(32, 32, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+            nn.Conv2d(32, 96, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 96, 96, 3, padding=1),
+            nn.Conv2d(96, 96, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+            nn.Conv2d(96, 256, 3, padding=1, stride=2),
             nn.SiLU(),
         )
 
         self.position_block = nn.Sequential(
-            conv_nd(dims, position_channels, 8, 3, padding=1),
+            nn.Conv2d(position_channels, 8, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 8, 8, 3, padding=1),
+            nn.Conv2d(8, 8, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
+            nn.Conv2d(8, 16, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 16, 16, 3, padding=1),
+            nn.Conv2d(16, 16, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            nn.Conv2d(16, 32, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 32, 32, 3, padding=1),
+            nn.Conv2d(32, 32, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 32, 64, 3, padding=1, stride=2),
+            nn.Conv2d(32, 64, 3, padding=1, stride=2),
             nn.SiLU(),
         )
+        self.time_embed = self.time_embed.to(device="cuda", dtype=torch.float16)
+        self.glyph_block = self.glyph_block.to(device="cuda", dtype=torch.float16)
+        self.position_block = self.position_block.to(device="cuda", dtype=torch.float16)
 
         self.vae = kwargs.get("vae")
         self.vae.eval()
 
-        self.fuse_block = zero_module(conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1))
+        self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1))
+        self.fuse_block = self.fuse_block.to(device="cuda", dtype=torch.float16)
 
     @torch.no_grad()
     def forward(
         self,
-        emb,
         context,
         text_info,
+        mode,
+        draw_pos,
+        ori_image,
+        num_images_per_prompt,
+        np_hint,
+        h=512,
+        w=512,
     ):
+        if mode == "generate":
+            edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
+        elif mode == "edit":
+            if draw_pos is None or ori_image is None:
+                raise ValueError("Reference image and position image are needed for text editing!")
+            if isinstance(ori_image, str):
+                ori_image = cv2.imread(ori_image)[..., ::-1]
+                if ori_image is None:
+                    raise ValueError(f"Can't read ori_image image from {ori_image}!")
+            elif isinstance(ori_image, torch.Tensor):
+                ori_image = ori_image.cpu().numpy()
+            else:
+                if not isinstance(ori_image, np.ndarray):
+                    raise ValueError(f"Unknown format of ori_image: {type(ori_image)}")
+            edit_image = ori_image.clip(1, 255) # for mask reason
+            edit_image = self.check_channels(edit_image)
+            edit_image = self.resize_image(
+                edit_image, max_length=768
+            ) # make w h multiple of 64, resize if w or h > max_length
+            h, w = edit_image.shape[:2] # change h, w by input ref_img
+
+        # get masked_x
+        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
+        masked_img = np.transpose(masked_img, (2, 0, 1))
+        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
+        if self.use_fp16:
+            masked_img = masked_img.half()
+        masked_x = self.encode_first_stage(masked_img[None, ...]).detach()
+        if self.use_fp16:
+            masked_x = masked_x.half()
+        text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)
+
         glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True)
         positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True)
-        enc_glyph = self.glyph_block(glyphs, emb, context)
-        enc_pos = self.position_block(positions, emb, context)
-        guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1))
+        t_emb = self.timestep_embedding(torch.tensor([1000], device="cuda"), self.model_channels, repeat_only=False)
+        if self.use_fp16:
+            t_emb = t_emb.half()
+        emb = self.time_embed(t_emb)
+        print(glyphs.shape, emb.shape, positions.shape, context.shape)
+        enc_glyph = self.glyph_block(glyphs.cuda(), emb, context)
+        enc_pos = self.position_block(positions.cuda(), emb, context)
+        guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"].cuda()], dim=1))
 
         return guided_hint
 
-    def encode_first_stage(self, masked_img):
-        return retrieve_latents(self.vae.encode(masked_img)) * self.vae.scale_factor
+    def timestep_embedding(self, timesteps, dim, max_period=10000, repeat_only=False):
+        """
+        Create sinusoidal timestep embeddings.
+        :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                          These may be fractional.
+        :param dim: the dimension of the output.
+        :param max_period: controls the minimum frequency of the embeddings.
+        :return: an [N x dim] Tensor of positional embeddings.
+        """
+        if not repeat_only:
+            half = dim // 2
+            freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+                device=timesteps.device
+            )
+            args = timesteps[:, None].float() * freqs[None]
+            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+            if dim % 2:
+                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        else:
+            embedding = repeat(timesteps, "b -> b d", d=dim)
+        return embedding
 
-    def arr2tensor(self, arr, bs):
-        arr = np.transpose(arr, (2, 0, 1))
-        _arr = torch.from_numpy(arr.copy()).float().cpu()
-        if self.use_fp16:
-            _arr = _arr.half()
-        _arr = torch.stack([_arr for _ in range(bs)], dim=0)
-        return _arr
+    def encode_first_stage(self, masked_img):
+        return retrieve_latents(self.vae.encode(masked_img)) * self.vae.config.scaling_factor
 
     def check_channels(self, image):
         channels = image.shape[2] if len(image.shape) == 3 else 1
@@ -155,79 +212,6 @@ def insert_spaces(self, string, nSpace):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]
 
-    def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True):
-        enlarge_polygon = polygon * scale
-        rect = cv2.minAreaRect(enlarge_polygon)
-        box = cv2.boxPoints(rect)
-        box = np.int0(box)
-        w, h = rect[1]
-        angle = rect[2]
-        if angle < -45:
-            angle += 90
-        angle = -angle
-        if w < h:
-            angle += 90
-
-        vert = False
-        if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng:
-            _w = max(box[:, 0]) - min(box[:, 0])
-            _h = max(box[:, 1]) - min(box[:, 1])
-            if _h >= _w:
-                vert = True
-                angle = 0
-
-        img = np.zeros((height * scale, width * scale, 3), np.uint8)
-        img = Image.fromarray(img)
-
-        # infer font size
-        image4ratio = Image.new("RGB", img.size, "white")
-        draw = ImageDraw.Draw(image4ratio)
-        _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
-        text_w = min(w, h) * (_tw / _th)
-        if text_w <= max(w, h):
-            # add space
-            if len(text) > 1 and not vert and add_space:
-                for i in range(1, 100):
-                    text_space = self.insert_spaces(text, i)
-                    _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
-                    if min(w, h) * (_tw2 / _th2) > max(w, h):
-                        break
-                text = self.insert_spaces(text, i - 1)
-            font_size = min(w, h) * 0.80
-        else:
-            shrink = 0.75 if vert else 0.85
-            font_size = min(w, h) / (text_w / max(w, h)) * shrink
-        new_font = font.font_variant(size=int(font_size))
-
-        left, top, right, bottom = new_font.getbbox(text)
-        text_width = right - left
-        text_height = bottom - top
-
-        layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
-        draw = ImageDraw.Draw(layer)
-        if not vert:
-            draw.text(
-                (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top),
-                text,
-                font=new_font,
-                fill=(255, 255, 255, 255),
-            )
-        else:
-            x_s = min(box[:, 0]) + _w // 2 - text_height // 2
-            y_s = min(box[:, 1])
-            for c in text:
-                draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
-                _, _t, _, _b = new_font.getbbox(c)
-                y_s += _b
-
-        rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))
-
-        x_offset = int((img.width - rotated_layer.width) / 2)
-        y_offset = int((img.height - rotated_layer.height) / 2)
-        img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
-        img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64)
-        return img
-
     def to(self, device):
         self.device = device
         self.glyph_block = self.glyph_block.to(device)
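
Note: the sinusoidal embedding introduced in this file can be sanity-checked on its own. The snippet below is a minimal standalone sketch, not part of the commit; it reimplements the same cos/sin construction as the new timestep_embedding method and pushes the result through a time_embed-style MLP. The model_channels value here is illustrative (the module receives its real value as a constructor argument).

import math

import torch
from torch import nn


def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
    # Same construction as the method added above: geometrically spaced
    # frequencies, then [cos | sin] of timestep * frequency.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad odd dims with a zero column
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


model_channels = 320  # illustrative value only
t_emb = sinusoidal_timestep_embedding(torch.tensor([1000]), model_channels)
time_embed = nn.Sequential(
    nn.Linear(model_channels, model_channels * 4),
    nn.SiLU(),
    nn.Linear(model_channels * 4, model_channels * 4),
)
emb = time_embed(t_emb)
print(t_emb.shape, emb.shape)  # torch.Size([1, 320]) torch.Size([1, 1280])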

examples/research_projects/anytext/embedding_manager.py

Lines changed: 1 addition & 0 deletions
@@ -185,6 +185,7 @@ def forward(
                 text_emb = torch.cat(self.text_embs_all[i], dim=0)
                 if sum(idx) != len(text_emb):
                     print("truncation for long caption...")
+                text_emb = text_emb.to(embedded_text.device)
                 embedded_text[i][idx] = text_emb[: sum(idx)]
         return embedded_text
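
Note: a hypothetical, minimal illustration of why the added device move matters; the tensors below are stand-ins, not the module's real data. Index assignment requires source and target to live on the same device, so text_emb is moved to embedded_text.device before being written into it.

import torch

embedded_text = torch.zeros(4, 8)               # would be a CUDA tensor in the real pipeline
idx = torch.tensor([True, True, False, False])  # placeholder-token mask
text_emb = torch.randn(2, 8)                    # may arrive on a different device
text_emb = text_emb.to(embedded_text.device)    # the line this commit adds
embedded_text[idx] = text_emb[: int(idx.sum())]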

examples/research_projects/anytext/frozen_clip_embedder_t3.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def __init__(
     ):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
-        self.transformer = CLIPTextModel.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(version).to(device)
         if use_vision:
             self.vit = CLIPVisionModelWithProjection.from_pretrained(version)
             self.processor = AutoProcessor.from_pretrained(version)

examples/research_projects/anytext/pipeline_anytext.py

Lines changed: 7 additions & 4 deletions
@@ -1118,14 +1118,13 @@ def __call__(
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
-        prompt_embeds, negative_prompt_embeds, text_info = self.text_embedding_module(
+        prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module(
             prompt,
             texts,
             negative_prompt,
             num_images_per_prompt,
             mode,
             draw_pos,
-            ori_image,
         )
 
         # For classifier free guidance, we need to do two forward passes.
@@ -1166,9 +1165,13 @@ def __call__(
         # )
         # height, width = image.shape[-2:]
         guided_hint = self.auxiliary_latent_module(
-            emb=timestep_cond,
-            context=prompt_embeds,
+            context=prompt_embeds[1],
             text_info=text_info,
+            mode=mode,
+            draw_pos=draw_pos,
+            ori_image=ori_image,
+            num_images_per_prompt=num_images_per_prompt,
+            np_hint=np_hint,
         )
         # elif isinstance(controlnet, MultiControlNetModel):
         # images = []
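
Note: a minimal sketch of the call flow after this change, not code from the commit. The helper name and its free-standing form are hypothetical; pipe stands for a pipeline object exposing the text_embedding_module and auxiliary_latent_module attributes used in the hunks above. text_embedding_module now also returns np_hint, and auxiliary_latent_module builds its own timestep embedding internally, so emb is no longer passed in.

def build_guided_hint(pipe, prompt, texts, negative_prompt, num_images_per_prompt, mode, draw_pos, ori_image):
    # Hypothetical helper mirroring the updated __call__ wiring shown above.
    prompt_embeds, negative_prompt_embeds, text_info, np_hint = pipe.text_embedding_module(
        prompt,
        texts,
        negative_prompt,
        num_images_per_prompt,
        mode,
        draw_pos,
    )
    guided_hint = pipe.auxiliary_latent_module(
        context=prompt_embeds[1],  # same indexing as in the hunk above
        text_info=text_info,
        mode=mode,  # "generate" or "edit"
        draw_pos=draw_pos,
        ori_image=ori_image,  # only required when mode == "edit"
        num_images_per_prompt=num_images_per_prompt,
        np_hint=np_hint,
    )
    return guided_hint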
