|
 # +> fuse layer
 # position l_p -> position block ->
 
+import math
 from typing import Optional
 
 import cv2
 import numpy as np
 import torch
-from PIL import Image, ImageDraw, ImageFont
+from einops import repeat
+from PIL import ImageFont
 from torch import nn
 
 from diffusers.utils import logging
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
 # Copied from diffusers.models.controlnet.zero_module
 def zero_module(module: nn.Module) -> nn.Module:
     for p in module.parameters():
@@ -56,74 +45,142 @@ def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels |
         self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60)
         self.use_fp16 = kwargs.get("use_fp16", False)
         self.device = kwargs.get("device", "cpu")
+        self.model_channels = model_channels
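+        # Time-embedding MLP over the sinusoidal features produced by
+        # timestep_embedding() below; the usual UNet widening of
+        # model_channels -> 4 * model_channels.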
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            nn.Linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            nn.Linear(time_embed_dim, time_embed_dim),
+        )
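+        # Glyph encoder: four stride-2 convs downsample the rendered glyph map 16x
+        # (e.g. 1024 -> 64) while widening channels glyph_channels -> 256.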
         self.glyph_block = nn.Sequential(
-            conv_nd(dims, glyph_channels, 8, 3, padding=1),
+            nn.Conv2d(glyph_channels, 8, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 8, 8, 3, padding=1),
+            nn.Conv2d(8, 8, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
+            nn.Conv2d(8, 16, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 16, 16, 3, padding=1),
+            nn.Conv2d(16, 16, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            nn.Conv2d(16, 32, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 32, 32, 3, padding=1),
+            nn.Conv2d(32, 32, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+            nn.Conv2d(32, 96, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 96, 96, 3, padding=1),
+            nn.Conv2d(96, 96, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+            nn.Conv2d(96, 256, 3, padding=1, stride=2),
             nn.SiLU(),
         )
 
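+        # Position encoder: three stride-2 convs downsample 8x, channels
+        # position_channels -> 64. With glyph maps rendered at twice the
+        # position-map resolution (as in AnyText), both encoders land on the
+        # same spatial grid for the concatenation in forward().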
         self.position_block = nn.Sequential(
-            conv_nd(dims, position_channels, 8, 3, padding=1),
+            nn.Conv2d(position_channels, 8, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 8, 8, 3, padding=1),
+            nn.Conv2d(8, 8, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
+            nn.Conv2d(8, 16, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 16, 16, 3, padding=1),
+            nn.Conv2d(16, 16, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            nn.Conv2d(16, 32, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 32, 32, 3, padding=1),
+            nn.Conv2d(32, 32, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 32, 64, 3, padding=1, stride=2),
+            nn.Conv2d(32, 64, 3, padding=1, stride=2),
             nn.SiLU(),
         )
+        # Honor the device/use_fp16 kwargs rather than hard-coding cuda/fp16.
+        self.dtype = torch.float16 if self.use_fp16 else torch.float32
+        self.time_embed = self.time_embed.to(device=self.device, dtype=self.dtype)
+        self.glyph_block = self.glyph_block.to(device=self.device, dtype=self.dtype)
+        self.position_block = self.position_block.to(device=self.device, dtype=self.dtype)
 
         self.vae = kwargs.get("vae")
         self.vae.eval()
 
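+        # ControlNet-style zero init: zero_module zeroes the fuse conv's parameters,
+        # so the guided hint starts as all zeros until trained weights are loaded.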
-        self.fuse_block = zero_module(conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1))
+        self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1))
+        self.fuse_block = self.fuse_block.to(device=self.device, dtype=self.dtype)
 
     @torch.no_grad()
     def forward(
         self,
-        emb,
         context,
         text_info,
+        mode,
+        draw_pos,
+        ori_image,
+        num_images_per_prompt,
+        np_hint,
+        h=512,
+        w=512,
     ):
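+        # Build the image whose masked VAE latent conditions the model: a flat
+        # gray canvas for "generate", the user's reference image for "edit".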
+        if mode == "generate":
+            edit_image = np.ones((h, w, 3)) * 127.5  # empty mask image
+        elif mode == "edit":
+            if draw_pos is None or ori_image is None:
+                raise ValueError("Reference image and position image are needed for text editing!")
+            if isinstance(ori_image, str):
+                img_path = ori_image
+                ori_image = cv2.imread(img_path)
+                if ori_image is None:
+                    raise ValueError(f"Can't read ori_image from {img_path}!")
+                ori_image = ori_image[..., ::-1]  # BGR -> RGB
+            elif isinstance(ori_image, torch.Tensor):
+                ori_image = ori_image.cpu().numpy()
+            elif not isinstance(ori_image, np.ndarray):
+                raise ValueError(f"Unknown format of ori_image: {type(ori_image)}")
+            edit_image = ori_image.clip(1, 255)  # for mask reason
+            edit_image = self.check_channels(edit_image)
+            edit_image = self.resize_image(
+                edit_image, max_length=768
+            )  # make w, h multiples of 64; resize if w or h > max_length
+            h, w = edit_image.shape[:2]  # h, w now follow the reference image
+        else:
+            raise ValueError(f"Unknown mode: {mode}")
+
+        # get masked_x
+        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
+        masked_img = np.transpose(masked_img, (2, 0, 1))
+        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
+        if self.use_fp16:
+            masked_img = masked_img.half()
+        masked_x = self.encode_first_stage(masked_img[None, ...]).detach()
+        if self.use_fp16:
+            masked_x = masked_x.half()
+        text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)
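+        # masked_x supplies the "+ 4" VAE-latent channels that fuse_block expects
+        # (256 glyph + 64 position + 4 latent).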
+
         glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True)
         positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True)
-        enc_glyph = self.glyph_block(glyphs, emb, context)
-        enc_pos = self.position_block(positions, emb, context)
-        guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1))
+        t_emb = self.timestep_embedding(torch.tensor([1000], device=self.device), self.model_channels, repeat_only=False)
+        if self.use_fp16:
+            t_emb = t_emb.half()
+        emb = self.time_embed(t_emb)
+        logger.debug("glyphs %s, emb %s, positions %s, context %s", glyphs.shape, emb.shape, positions.shape, context.shape)
+        # The plain nn.Sequential encoders take a single tensor input; emb/context
+        # are no longer consumed here (emb is kept for parity with the original
+        # timestep-aware blocks).
+        enc_glyph = self.glyph_block(glyphs.to(device=self.device, dtype=self.dtype))
+        enc_pos = self.position_block(positions.to(device=self.device, dtype=self.dtype))
+        guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"].to(device=self.device, dtype=self.dtype)], dim=1))
 
         return guided_hint
 
-    def encode_first_stage(self, masked_img):
-        return retrieve_latents(self.vae.encode(masked_img)) * self.vae.scale_factor
+    def timestep_embedding(self, timesteps, dim, max_period=10000, repeat_only=False):
+        """
+        Create sinusoidal timestep embeddings.
+
+        :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+        :param dim: the dimension of the output.
+        :param max_period: controls the minimum frequency of the embeddings.
+        :return: an [N x dim] Tensor of positional embeddings.
+        """
+        if not repeat_only:
+            half = dim // 2
+            freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+                device=timesteps.device
+            )
+            args = timesteps[:, None].float() * freqs[None]
+            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+            if dim % 2:
+                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        else:
+            embedding = repeat(timesteps, "b -> b d", d=dim)
+        return embedding
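+    # Shape check (illustrative): self.timestep_embedding(torch.tensor([1000]), 320)
+    # returns a [1, 320] float32 tensor (cos half then sin half); with
+    # repeat_only=True the raw timestep values are broadcast to [1, 320] instead.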
 
-    def arr2tensor(self, arr, bs):
-        arr = np.transpose(arr, (2, 0, 1))
-        _arr = torch.from_numpy(arr.copy()).float().cpu()
-        if self.use_fp16:
-            _arr = _arr.half()
-        _arr = torch.stack([_arr for _ in range(bs)], dim=0)
-        return _arr
+    def encode_first_stage(self, masked_img):
+        return retrieve_latents(self.vae.encode(masked_img)) * self.vae.config.scaling_factor
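+    # retrieve_latents is assumed to be the diffusers pipeline helper that samples
+    # the AutoencoderKL posterior; scaling by config.scaling_factor matches how
+    # diffusers pipelines normalize latents.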
 
     def check_channels(self, image):
         channels = image.shape[2] if len(image.shape) == 3 else 1
@@ -155,79 +212,6 @@ def insert_spaces(self, string, nSpace): |
             new_string += char + " " * nSpace
         return new_string[:-nSpace]
 
-    def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True):
-        enlarge_polygon = polygon * scale
-        rect = cv2.minAreaRect(enlarge_polygon)
-        box = cv2.boxPoints(rect)
-        box = np.int0(box)
-        w, h = rect[1]
-        angle = rect[2]
-        if angle < -45:
-            angle += 90
-        angle = -angle
-        if w < h:
-            angle += 90
-
-        vert = False
-        if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng:
-            _w = max(box[:, 0]) - min(box[:, 0])
-            _h = max(box[:, 1]) - min(box[:, 1])
-            if _h >= _w:
-                vert = True
-                angle = 0
-
-        img = np.zeros((height * scale, width * scale, 3), np.uint8)
-        img = Image.fromarray(img)
-
-        # infer font size
-        image4ratio = Image.new("RGB", img.size, "white")
-        draw = ImageDraw.Draw(image4ratio)
-        _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
-        text_w = min(w, h) * (_tw / _th)
-        if text_w <= max(w, h):
-            # add space
-            if len(text) > 1 and not vert and add_space:
-                for i in range(1, 100):
-                    text_space = self.insert_spaces(text, i)
-                    _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
-                    if min(w, h) * (_tw2 / _th2) > max(w, h):
-                        break
-                text = self.insert_spaces(text, i - 1)
-            font_size = min(w, h) * 0.80
-        else:
-            shrink = 0.75 if vert else 0.85
-            font_size = min(w, h) / (text_w / max(w, h)) * shrink
-        new_font = font.font_variant(size=int(font_size))
-
-        left, top, right, bottom = new_font.getbbox(text)
-        text_width = right - left
-        text_height = bottom - top
-
-        layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
-        draw = ImageDraw.Draw(layer)
-        if not vert:
-            draw.text(
-                (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top),
-                text,
-                font=new_font,
-                fill=(255, 255, 255, 255),
-            )
-        else:
-            x_s = min(box[:, 0]) + _w // 2 - text_height // 2
-            y_s = min(box[:, 1])
-            for c in text:
-                draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
-                _, _t, _, _b = new_font.getbbox(c)
-                y_s += _b
-
-        rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))
-
-        x_offset = int((img.width - rotated_layer.width) / 2)
-        y_offset = int((img.height - rotated_layer.height) / 2)
-        img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
-        img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64)
-        return img
-
     def to(self, device):
         self.device = device
         self.glyph_block = self.glyph_block.to(device)