Commit f347ff2

refactor: Simplify

1 parent fdf0275

2 files changed: +17, -105 lines

examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 7 additions & 12 deletions
@@ -13,13 +13,6 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-# Copied from diffusers.models.controlnet.zero_module
-def zero_module(module: nn.Module) -> nn.Module:
-    for p in module.parameters():
-        nn.init.zeros_(p)
-    return module
-
-
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -83,18 +76,20 @@ def __init__(
 
         self.vae = vae.eval()
 
-        self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1))
+        self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)
 
         self.glyph_block.load_state_dict(
             load_file("AuxiliaryLatentModule/glyph_block.safetensors", device=self.device)
         )
-        self.glyph_block = self.glyph_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32)
         self.position_block.load_state_dict(
             load_file("AuxiliaryLatentModule/position_block.safetensors", device=self.device)
         )
-        self.position_block = self.position_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32)
         self.fuse_block.load_state_dict(load_file("AuxiliaryLatentModule/fuse_block.safetensors", device=self.device))
-        self.fuse_block = self.fuse_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32)
+
+        if use_fp16:
+            self.glyph_block = self.glyph_block.to(dtype=torch.float16)
+            self.position_block = self.position_block.to(dtype=torch.float16)
+            self.fuse_block = self.fuse_block.to(dtype=torch.float16)
 
     @torch.no_grad()
     def forward(
@@ -181,6 +176,6 @@ def to(self, device):
         self.device = device
         self.glyph_block = self.glyph_block.to(device)
         self.position_block = self.position_block.to(device)
-        self.vae = self.vae.to(device)
         self.fuse_block = self.fuse_block.to(device)
+        self.vae = self.vae.to(device)
         return self
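For context, the dropped `zero_module` helper (previously copied from diffusers.models.controlnet) only zero-initializes a module's parameters. Because each block's weights are loaded from a safetensors checkpoint immediately after construction, that zeroing never survives the `load_state_dict` call, which is what makes the wrapper safe to remove; the per-block dtype ternaries are likewise folded into a single `if use_fp16:` cast. A minimal illustration, with made-up channel sizes:

import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    # The removed helper: zero-initializes every parameter in place.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

# Illustration only (4 and 8 are arbitrary channel counts): any zeroed weights are
# replaced as soon as a checkpoint is loaded into the module.
conv = zero_module(nn.Conv2d(4, 8, 3, padding=1))
conv.load_state_dict(nn.Conv2d(4, 8, 3, padding=1).state_dict())
print(conv.weight.abs().sum() > 0)  # tensor(True): the zero init did not survive the load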

examples/research_projects/anytext/embedding_manager.py

Lines changed: 10 additions & 93 deletions
@@ -5,27 +5,7 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-
-
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-# Copied from diffusers.models.controlnet.zero_module
-def zero_module(module: nn.Module) -> nn.Module:
-    for p in module.parameters():
-        nn.init.zeros_(p)
-    return module
+from safetensors.torch import load_file
 
 
 def get_clip_token_for_string(tokenizer, string):
@@ -45,111 +25,47 @@ def get_clip_token_for_string(tokenizer, string):
     return tokens[0, 1]
 
 
-def get_bert_token_for_string(tokenizer, string):
-    token = tokenizer(string)
-    assert (
-        torch.count_nonzero(token) == 3
-    ), f"String '{string}' maps to more than a single token. Please use another string"
-    token = token[0, 1]
-    return token
-
-
-def get_clip_vision_emb(encoder, processor, img):
-    _img = img.repeat(1, 3, 1, 1) * 255
-    inputs = processor(images=_img, return_tensors="pt")
-    inputs["pixel_values"] = inputs["pixel_values"].to(img.device)
-    outputs = encoder(**inputs)
-    emb = outputs.image_embeds
-    return emb
-
-
 def get_recog_emb(encoder, img_list):
     _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list]
     encoder.predictor.eval()
     _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
     return preds_neck
 
 
-def pad_H(x):
-    _, _, H, W = x.shape
-    p_top = (W - H) // 2
-    p_bot = W - H - p_top
-    return F.pad(x, (0, 0, p_top, p_bot))
-
-
-class EncodeNet(nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super(EncodeNet, self).__init__()
-        chan = 16
-        n_layer = 4  # downsample
-
-        self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
-        self.conv_list = nn.ModuleList([])
-        _c = chan
-        for i in range(n_layer):
-            self.conv_list.append(conv_nd(2, _c, _c * 2, 3, padding=1, stride=2))
-            _c *= 2
-        self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
-        self.avgpool = nn.AdaptiveAvgPool2d(1)
-        self.act = nn.SiLU()
-
-    def forward(self, x):
-        x = self.act(self.conv1(x))
-        for layer in self.conv_list:
-            x = self.act(layer(x))
-        x = self.act(self.conv2(x))
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        return x
-
-
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
         embedder,
-        position_channels=1,
         placeholder_string="*",
-        add_pos=False,
-        emb_type="ocr",
         use_fp16=False,
     ):
         super().__init__()
         get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
         token_dim = 768
         self.get_recog_emb = None
-        token_dim = 1280
         self.token_dim = token_dim
-        self.emb_type = emb_type
 
-        self.add_pos = add_pos
-        if add_pos:
-            self.position_encoder = EncodeNet(position_channels, token_dim)
-        if emb_type == "ocr":
-            self.proj = nn.Sequential(zero_module(nn.Linear(40 * 64, token_dim)), nn.LayerNorm(token_dim))
-            self.proj = self.proj.to(dtype=torch.float16 if use_fp16 else torch.float32)
+        self.proj = nn.Linear(40 * 64, token_dim)
+        self.proj.load_state_dict(load_file("EmbeddingManager/embedding_manager.safetensors", device=self.device))
+        if use_fp16:
+            self.proj = self.proj.to(dtype=torch.float16)
 
         self.placeholder_token = get_token_for_string(placeholder_string)
 
+    @torch.no_grad()
     def encode_text(self, text_info):
-        if self.get_recog_emb is None and self.emb_type == "ocr":
+        if self.get_recog_emb is None:
             self.get_recog_emb = partial(get_recog_emb, self.recog)
 
         gline_list = []
-        pos_list = []
         for i in range(len(text_info["n_lines"])):  # sample index in a batch
             n_lines = text_info["n_lines"][i]
             for j in range(n_lines):  # line
                 gline_list += [text_info["gly_line"][j][i : i + 1]]
-                if self.add_pos:
-                    pos_list += [text_info["positions"][j][i : i + 1]]
 
         if len(gline_list) > 0:
-            if self.emb_type == "ocr":
-                recog_emb = self.get_recog_emb(gline_list)
-                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
-            if self.add_pos:
-                enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
-                enc_glyph = enc_glyph + enc_pos
+            recog_emb = self.get_recog_emb(gline_list)
+            enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
 
         self.text_embs_all = []
         n_idx = 0
@@ -161,6 +77,7 @@ def encode_text(self, text_info):
             n_idx += 1
         self.text_embs_all += [text_embs]
 
+    @torch.no_grad()
     def forward(
         self,
         tokenized_text,
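With the `emb_type` and `add_pos` branches removed, `encode_text` always takes the OCR path: the recognizer's per-line features are flattened and projected by the single `nn.Linear` into the 768-dim token space. A minimal sketch of just that projection step, using a random stand-in for the recognizer output (its per-line (40, 64) feature shape is inferred from the `40 * 64` input dimension above):

import torch
import torch.nn as nn

token_dim = 768                       # CLIP token embedding width used in __init__
proj = nn.Linear(40 * 64, token_dim)  # same shape as the simplified self.proj

# Stand-in for get_recog_emb output: three glyph lines, each a (40, 64) feature map.
recog_emb = torch.randn(3, 40, 64)
enc_glyph = proj(recog_emb.reshape(recog_emb.shape[0], -1))
print(enc_glyph.shape)                # torch.Size([3, 768])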
