Commit be4a319

Commit message: Up

1 parent f60a72b · commit be4a319

7 files changed: 76 additions, 94 deletions

examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 17 additions & 16 deletions
@@ -37,9 +37,9 @@ def retrieve_latents(
 class AuxiliaryLatentModule(nn.Module):
     def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, **kwargs):
         super().__init__()
-        self.font = ImageFont.truetype("Arial_Unicode.ttf", 60)
+        self.font = ImageFont.truetype("font/Arial_Unicode.ttf", 60)
         self.use_fp16 = kwargs.get("use_fp16", False)
-        self.device = kwargs.get("device", "cuda")
+        self.device = kwargs.get("device", "cpu")

         self.glyph_block = nn.Sequential(
             nn.Conv2d(glyph_channels, 8, 3, padding=1),
@@ -78,17 +78,22 @@ def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, **
             nn.Conv2d(32, 64, 3, padding=1, stride=2),
             nn.SiLU(),
         )
-        self.glyph_block.load_state_dict(load_file("glyph_block.safetensors"))
-        self.position_block.load_state_dict(load_file("position_block.safetensors"))
-        self.glyph_block = self.glyph_block.to(device="cuda", dtype=torch.float16)
-        self.position_block = self.position_block.to(device="cuda", dtype=torch.float16)

         self.vae = kwargs.get("vae")
         self.vae.eval()

         self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1))
-        self.fuse_block.load_state_dict(load_file("fuse_block.safetensors"))
-        self.fuse_block = self.fuse_block.to(device="cuda", dtype=torch.float16)
+
+        self.glyph_block.load_state_dict(
+            load_file("AuxiliaryLatentModule/glyph_block.safetensors", device=self.device)
+        )
+        self.glyph_block = self.glyph_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32)
+        self.position_block.load_state_dict(
+            load_file("AuxiliaryLatentModule/position_block.safetensors", device=self.device)
+        )
+        self.position_block = self.position_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32)
+        self.fuse_block.load_state_dict(load_file("AuxiliaryLatentModule/fuse_block.safetensors", device=self.device))
+        self.fuse_block = self.fuse_block.to(dtype=torch.float16 if self.use_fp16 else torch.float32)

     @torch.no_grad()
     def forward(
@@ -121,30 +126,26 @@ def forward(
         edit_image = self.resize_image(
             edit_image, max_length=768
         )  # make w h multiple of 64, resize if w or h > max_length
-        h, w = edit_image.shape[:2]  # change h, w by input ref_img

         # get masked_x
         masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
         masked_img = np.transpose(masked_img, (2, 0, 1))
         masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
         if self.use_fp16:
             masked_img = masked_img.half()
-        masked_x = self.encode_first_stage(masked_img[None, ...]).detach()
+        masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
         if self.use_fp16:
             masked_x = masked_x.half()
         text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)

         glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True)
         positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True)
-        enc_glyph = self.glyph_block(glyphs.cuda())
-        enc_pos = self.position_block(positions.cuda())
-        guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"].cuda()], dim=1))
+        enc_glyph = self.glyph_block(glyphs)
+        enc_pos = self.position_block(positions)
+        guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1))

         return guided_hint

-    def encode_first_stage(self, masked_img):
-        return retrieve_latents(self.vae.encode(masked_img)) * self.vae.config.scaling_factor
-
     def check_channels(self, image):
         channels = image.shape[2] if len(image.shape) == 3 else 1
         if channels == 1:
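
A minimal construction sketch (not part of the commit) showing how the refactored module picks up the new CPU/float32 defaults. It assumes the script is run from examples/research_projects/anytext/ with the relocated AuxiliaryLatentModule/*.safetensors weights and font/Arial_Unicode.ttf present locally; the VAE checkpoint id is only illustrative:

from diffusers import AutoencoderKL

from auxiliary_latent_module import AuxiliaryLatentModule

# Illustrative VAE checkpoint; any Stable Diffusion 1.x VAE works for this sketch.
vae = AutoencoderKL.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae")

# New defaults: blocks load onto the CPU in float32; pass use_fp16=True and device="cuda"
# to opt back into the previously hard-coded behaviour explicitly.
aux = AuxiliaryLatentModule(vae=vae, use_fp16=False, device="cpu")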

examples/research_projects/anytext/convert_from_ckpt.py

Lines changed: 8 additions & 8 deletions
@@ -34,13 +34,18 @@
     CLIPVisionModelWithProjection,
 )

-from ...models import (
+from diffusers.models import (
     AutoencoderKL,
     ControlNetModel,
     PriorTransformer,
     UNet2DConditionModel,
 )
-from ...schedulers import (
+from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+from diffusers.schedulers import (
     DDIMScheduler,
     DDPMScheduler,
     DPMSolverMultistepScheduler,
@@ -51,12 +56,7 @@
     PNDMScheduler,
     UnCLIPScheduler,
 )
-from ...utils import is_accelerate_available, logging
-from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
-from ..paint_by_example import PaintByExampleImageEncoder
-from ..pipeline_utils import DiffusionPipeline
-from .safety_checker import StableDiffusionSafetyChecker
-from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+from diffusers.utils import is_accelerate_available, logging


 if is_accelerate_available():
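
The switch from relative to absolute imports is what lets the converter run as a standalone research script: `from ...models import ...` only resolves when the file lives inside the diffusers package, while the absolute form works anywhere the library is installed. A quick sanity check of the new import style, assuming diffusers is installed:

# Same public diffusers entry points the script now imports directly.
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import is_accelerate_available

print(is_accelerate_available())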

examples/research_projects/anytext/embedding_manager.py

Lines changed: 6 additions & 16 deletions
@@ -107,25 +107,17 @@ class EmbeddingManager(nn.Module):
     def __init__(
         self,
         embedder,
-        valid=True,
-        glyph_channels=20,
         position_channels=1,
         placeholder_string="*",
         add_pos=False,
         emb_type="ocr",
-        **kwargs,
+        use_fp16=False,
     ):
         super().__init__()
-        if hasattr(embedder, "tokenizer"):  # using Stable Diffusion's CLIP encoder
-            get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
-            token_dim = 768
-            if hasattr(embedder, "vit"):
-                assert emb_type == "vit"
-                self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
-            self.get_recog_emb = None
-        else:  # using LDM's BERT encoder
-            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
-            token_dim = 1280
+        get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
+        token_dim = 768
+        self.get_recog_emb = None
+        token_dim = 1280
         self.token_dim = token_dim
         self.emb_type = emb_type

@@ -134,9 +126,7 @@ def __init__(
             self.position_encoder = EncodeNet(position_channels, token_dim)
         if emb_type == "ocr":
             self.proj = nn.Sequential(zero_module(nn.Linear(40 * 64, token_dim)), nn.LayerNorm(token_dim))
-            self.proj = self.proj.to(dtype=torch.float16 if kwargs.get("use_fp16", False) else torch.float32)
-        if emb_type == "conv":
-            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
+            self.proj = self.proj.to(dtype=torch.float16 if use_fp16 else torch.float32)

         self.placeholder_token = get_token_for_string(placeholder_string)
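
A short instantiation sketch (not from the commit) for the slimmed-down constructor; it assumes the example directory is the working directory, that downloading the default CLIP weights is acceptable, and that the default "ocr" embedding type is wanted:

from embedding_manager import EmbeddingManager
from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3

# The manager now always expects a CLIP-style embedder exposing a `tokenizer`;
# the BERT/ViT branches and the `valid`/`glyph_channels` arguments are gone.
embedder = FrozenCLIPEmbedderT3(device="cpu", use_fp16=False)
manager = EmbeddingManager(embedder, use_fp16=False)
print(manager.token_dim)  # 1280, per the assignment kept in the hunk above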

examples/research_projects/anytext/frozen_clip_embedder_t3.py

Lines changed: 10 additions & 6 deletions
@@ -1,6 +1,6 @@
 import torch
 from torch import nn
-from transformers import AutoProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPTextModel, CLIPTokenizer
 from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask


@@ -16,14 +16,18 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""

     def __init__(
-        self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, use_vision=False
+        self,
+        version="openai/clip-vit-large-patch14",
+        device="cpu",
+        max_length=77,
+        freeze=True,
+        use_fp16=False,
     ):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
-        self.transformer = CLIPTextModel.from_pretrained(version, torch_dtype=torch.float16).to(device)
-        if use_vision:
-            self.vit = CLIPVisionModelWithProjection.from_pretrained(version)
-            self.processor = AutoProcessor.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(
+            version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32
+        ).to(device)
         self.device = device
         self.max_length = max_length
         if freeze:
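
A minimal usage sketch for the new constructor defaults (CPU, float32, safetensors weights). The version string is the class default; downloading it is assumed to be acceptable:

from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3

# use_fp16=False keeps the text encoder in float32; use_fp16=True restores the old half-precision load.
clip = FrozenCLIPEmbedderT3(version="openai/clip-vit-large-patch14", device="cpu", use_fp16=False)
print(clip.transformer.dtype)  # torch.float32 with these defaults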

examples/research_projects/anytext/pipeline_anytext.py

Lines changed: 7 additions & 2 deletions
@@ -218,10 +218,15 @@ def __init__(
         feature_extractor: CLIPImageProcessor,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
+        font_path: str = "font/Arial_Unicode.ttf",
     ):
         super().__init__()
-        self.text_embedding_module = TextEmbeddingModule(use_fp16=unet.dtype == torch.float16)
-        self.auxiliary_latent_module = AuxiliaryLatentModule(vae=vae, use_fp16=unet.dtype == torch.float16)
+        self.text_embedding_module = TextEmbeddingModule(
+            use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
+        )
+        self.auxiliary_latent_module = AuxiliaryLatentModule(
+            vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
+        )

         if safety_checker is None and requires_safety_checker:
             logger.warning(
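
A loading sketch for the new `font_path` argument. The checkpoint path is hypothetical (this commit ships no weights), the pipeline class name is taken from pipeline_anytext.py, and forwarding of `font_path` through from_pretrained is an assumption based on the argument now appearing in the __init__ signature:

import torch

from pipeline_anytext import AnyTextPipeline

# Hypothetical local checkpoint directory laid out for AnyTextPipeline.
pipe = AnyTextPipeline.from_pretrained(
    "path/to/anytext-checkpoint",
    font_path="font/Arial_Unicode.ttf",  # handed on to TextEmbeddingModule and AuxiliaryLatentModule
    torch_dtype=torch.float16,
)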

examples/research_projects/anytext/recognizer.py

Lines changed: 18 additions & 25 deletions
@@ -78,37 +78,30 @@ def crop_image(src_img, mask):
     return result


-def create_predictor(model_dir=None, model_lang="ch", is_onnx=False):
+def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
     model_file_path = model_dir
     if model_file_path is not None and not os.path.exists(model_file_path):
         raise ValueError("not find model file path {}".format(model_file_path))

-    if is_onnx:
-        import onnxruntime as ort
-
-        sess = ort.InferenceSession(
-            model_file_path, providers=["CPUExecutionProvider"]
-        )  # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
-        return sess
+    if model_lang == "ch":
+        n_class = 6625
+    elif model_lang == "en":
+        n_class = 97
     else:
-        if model_lang == "ch":
-            n_class = 6625
-        elif model_lang == "en":
-            n_class = 97
-        else:
-            raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
-        rec_config = edict(
-            in_channels=3,
-            backbone=edict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"),
-            neck=edict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True),
-            head=edict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True),
+        raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
+    rec_config = edict(
+        in_channels=3,
+        backbone=edict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"),
+        neck=edict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True),
+        head=edict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True),
+    )
+
+    rec_model = RecModel(rec_config)
+    if model_file_path is not None:
+        rec_model.load_state_dict(load_file(model_file_path, device=device)).to(
+            dtype=torch.float16 if use_fp16 else torch.float32
         )
-
-    rec_model = RecModel(rec_config)
-    if model_file_path is not None:
-        rec_model.load_state_dict(load_file(model_file_path))
-    rec_model.eval()
-    return rec_model.eval()
+    return rec_model


 def _check_image_file(path):
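
A call sketch for the reworked factory (ONNX path removed, weights read with safetensors). Passing model_dir=None builds the recognizer architecture without touching any weight file, which keeps the sketch self-contained; it assumes it is run from the example directory so that recognizer.py and its own dependencies (the RecModel implementation, easydict) import cleanly:

from recognizer import create_predictor

# Architecture-only build: no safetensors file is read when model_dir is None.
# device/use_fp16 choose where and in which precision loaded weights would be placed.
text_predictor = create_predictor(model_dir=None, model_lang="en", device="cpu", use_fp16=False)
print(sum(p.numel() for p in text_predictor.parameters()))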

examples/research_projects/anytext/text_embedding_module.py

Lines changed: 10 additions & 21 deletions
@@ -16,36 +16,25 @@


 class TextEmbeddingModule(nn.Module):
-    def __init__(self, use_fp16):
+    def __init__(self, font_path, use_fp16=False, device="cpu"):
         super().__init__()
         self.use_fp16 = use_fp16
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
         # TODO: Learn if the recommended font file is free to use
-        self.font = ImageFont.truetype("Arial_Unicode.ttf", 60)
-        self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device)
-        self.embedding_manager_config = {
-            "valid": True,
-            "emb_type": "ocr",
-            "glyph_channels": 1,
-            "position_channels": 1,
-            "add_pos": False,
-            "placeholder_string": "*",
-            "use_fp16": self.use_fp16,
-        }
-        self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, **self.embedding_manager_config)
+        self.font = ImageFont.truetype(font_path, 60)
+        self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device, use_fp16=self.use_fp16)
+        self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=self.use_fp16)
         # TODO: Understand the reason of param.requires_grad = True
         for param in self.embedding_manager.embedding_parameters():
             param.requires_grad = True
-        rec_model_dir = "ppv3_rec.safetensors"
-        self.text_predictor = create_predictor(rec_model_dir).eval()
+        rec_model_dir = "OCR/ppv3_rec.safetensors"
+        self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16)
         args = {}
         args["rec_image_shape"] = "3, 48, 320"
         args["rec_batch_num"] = 6
-        args["rec_char_dict_path"] = "ppocr_keys_v1.txt"
-        args["use_fp16"] = True
-        self.cn_recognizer = TextRecognizer(
-            args, self.text_predictor.to(dtype=torch.float16 if use_fp16 else torch.float32)
-        )
+        args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt"
+        args["use_fp16"] = self.use_fp16
+        self.cn_recognizer = TextRecognizer(args, self.text_predictor, device=self.device, use_fp16=self.use_fp16)
         for param in self.text_predictor.parameters():
             param.requires_grad = False
         self.embedding_manager.recog = self.cn_recognizer
