Commit 13b7ecf

Committed: 🆙
1 parent 0d44b5b commit 13b7ecf

File tree

1 file changed: +91, -82 lines
examples/research_projects/anytext/anytext.py

Lines changed: 91 additions & 82 deletions
@@ -34,7 +34,6 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from easydict import EasyDict as edict
 from huggingface_hub import hf_hub_download
 from ocr_recog.RecModel import RecModel
 from PIL import Image, ImageDraw, ImageFont
@@ -58,6 +57,8 @@
 from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.configuration_utils import register_to_config, ConfigMixin
+from diffusers.models.modeling_utils import ModelMixin
 from diffusers.utils import (
     USE_PEFT_BACKEND,
     deprecate,
@@ -203,18 +204,18 @@ def get_recog_emb(encoder, img_list):
     return preds_neck


-class EmbeddingManager(nn.Module):
+class EmbeddingManager(ModelMixin, ConfigMixin):
+    @register_to_config
     def __init__(
         self,
         embedder,
         placeholder_string="*",
         use_fp16=False,
+        token_dim = 768,
+        get_recog_emb = None,
     ):
         super().__init__()
         get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
-        token_dim = 768
-        self.get_recog_emb = None
-        self.token_dim = token_dim

         self.proj = nn.Linear(40 * 64, token_dim)
         proj_dir = hf_hub_download(
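Note (not part of the diff): in diffusers, a class that inherits ModelMixin and ConfigMixin and decorates __init__ with @register_to_config has its constructor arguments recorded on self.config, which is why later hunks read self.config.get_recog_emb instead of a plain attribute. A minimal sketch of that mechanism, with an illustrative class name and values:

```python
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class TinyEmbedding(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, token_dim=768, use_fp16=False):
        super().__init__()
        # token_dim and use_fp16 are captured into the config automatically;
        # only the Linear layer holds learnable parameters.
        self.proj = nn.Linear(40 * 64, token_dim)


module = TinyEmbedding()
print(module.config.token_dim)  # 768, read back from the registered config
print(module.dtype)             # torch.float32, a ModelMixin property
```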
@@ -226,12 +227,14 @@ def __init__(
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)

+        # self.register_parameter("proj", proj)
         self.placeholder_token = get_token_for_string(placeholder_string)
+        # self.register_config(placeholder_token=placeholder_token)

     @torch.no_grad()
     def encode_text(self, text_info):
-        if self.get_recog_emb is None:
-            self.get_recog_emb = partial(get_recog_emb, self.recog)
+        if self.config.get_recog_emb is None:
+            self.config.get_recog_emb = partial(get_recog_emb, self.recog)

         gline_list = []
         for i in range(len(text_info["n_lines"])):  # sample index in a batch
@@ -240,7 +243,7 @@ def encode_text(self, text_info):
                 gline_list += [text_info["gly_line"][j][i : i + 1]]

         if len(gline_list) > 0:
-            recog_emb = self.get_recog_emb(gline_list)
+            recog_emb = self.config.get_recog_emb(gline_list)
             enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype))

         self.text_embs_all = []
@@ -332,13 +335,12 @@ def crop_image(src_img, mask):
     return result


-def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
-    if model_dir is None or not os.path.exists(model_dir):
-        model_dir = hf_hub_download(
-            repo_id="tolgacangoz/anytext",
-            filename="text_embedding_module/OCR/ppv3_rec.pth",
-            cache_dir=HF_MODULES_CACHE,
-        )
+def create_predictor(model_lang="ch", device="cpu", use_fp16=False):
+    model_dir = hf_hub_download(
+        repo_id="tolgacangoz/anytext",
+        filename="text_embedding_module/OCR/ppv3_rec.pth",
+        cache_dir=HF_MODULES_CACHE,
+    )
     if not os.path.exists(model_dir):
         raise ValueError("not find model file path {}".format(model_dir))

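Note (not part of the diff): the simplified create_predictor above always resolves the OCR checkpoint through the Hub cache instead of an optional local path. Standalone, the download step amounts roughly to the following; cache_dir is left at its default here, whereas the module passes HF_MODULES_CACHE.

```python
from huggingface_hub import hf_hub_download

# Downloads (or reuses from the local cache) the recognizer weights referenced by
# create_predictor and returns the local file path.
ckpt_path = hf_hub_download(
    repo_id="tolgacangoz/anytext",
    filename="text_embedding_module/OCR/ppv3_rec.pth",
)
print(ckpt_path)
```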
@@ -533,24 +535,24 @@ def encode(self, *args, **kwargs):
         raise NotImplementedError


-class FrozenCLIPEmbedderT3(AbstractEncoder):
+class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-
+    @register_to_config
     def __init__(
         self,
-        version="openai/clip-vit-large-patch14",
         device="cpu",
         max_length=77,
         freeze=True,
         use_fp16=False,
+        variant: Optional[str] = None,
     ):
         super().__init__()
-        self.tokenizer = CLIPTokenizer.from_pretrained(version)
-        self.transformer = CLIPTextModel.from_pretrained(
-            version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32
-        ).to(device)
-        self.device = device
-        self.max_length = max_length
+        self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
+        self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder",
+                                                         torch_dtype=torch.float16 if use_fp16 else torch.float32,
+                                                         variant="fp16" if use_fp16 else None)
+        # self.device = device
+        # self.max_length = max_length
         if freeze:
             self.freeze()

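Note (not part of the diff): the new FrozenCLIPEmbedderT3 loads the tokenizer and text encoder from subfolders of the pipeline repo instead of openai/clip-vit-large-patch14. Standalone, the loading pattern looks roughly like this (fp16 branch shown; the variant only resolves if the repo actually ships fp16 weight files):

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

repo_id = "tolgacangoz/anytext"
tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    repo_id,
    subfolder="text_encoder",
    torch_dtype=torch.float16,  # mirrors use_fp16=True
    variant="fp16",             # selects the *.fp16.safetensors files if present
)
```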
@@ -686,7 +688,7 @@ def forward(self, text, **kwargs):
         batch_encoding = self.tokenizer(
             text,
             truncation=False,
-            max_length=self.max_length,
+            max_length=self.config.max_length,
             return_length=True,
             return_overflowing_tokens=False,
             padding="longest",
@@ -729,34 +731,39 @@ def split_chunks(self, input_ids, chunk_size=75):
             tokens_list.append(remaining_group_pad)
         return tokens_list

-    def to(self, *args, **kwargs):
-        self.transformer = self.transformer.to(*args, **kwargs)
-        self.device = self.transformer.device
-        return self
+    # def to(self, *args, **kwargs):
+    #     self.transformer = self.transformer.to(*args, **kwargs)
+    #     self.device = self.transformer.device
+    #     return self


-class TextEmbeddingModule(nn.Module):
+class TextEmbeddingModule(ModelMixin, ConfigMixin):
+    @register_to_config
     def __init__(self, font_path, use_fp16=False, device="cpu"):
         super().__init__()
-        self.font = ImageFont.truetype(font_path, 60)
-        self.use_fp16 = use_fp16
-        self.device = device
+        font = ImageFont.truetype(font_path, 60)
+
+        # self.use_fp16 = use_fp16
+        # self.device = device
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
         self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
-        rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
-        self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
-        args = {}
-        args["rec_image_shape"] = "3, 48, 320"
-        args["rec_batch_num"] = 6
-        args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt"
-        args["rec_char_dict_path"] = hf_hub_download(
-            repo_id="tolgacangoz/anytext",
-            filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
-            cache_dir=HF_MODULES_CACHE,
-        )
-        args["use_fp16"] = use_fp16
+        self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
+        args = {"rec_image_shape": "3, 48, 320",
+                "rec_batch_num": 6,
+                "rec_char_dict_path": hf_hub_download(
+                    repo_id="tolgacangoz/anytext",
+                    filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
+                    cache_dir=HF_MODULES_CACHE,
+                ),
+                "use_fp16": use_fp16}
         self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)

+        # self.register_modules(
+        #     frozen_CLIP_embedder_t3=frozen_CLIP_embedder_t3,
+        #     embedding_manager=embedding_manager,
+        # )
+        self.register_to_config(font=font)
+
     @torch.no_grad()
     def forward(
         self,
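Note (not part of the diff): besides the decorator, ConfigMixin also exposes register_to_config() as an instance method, which is how the loaded ImageFont ends up under self.config.font in the hunk above. A small illustrative sketch; the placeholder string stands in for ImageFont.truetype(font_path, 60):

```python
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class FontHolder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, font_size: int = 60):
        super().__init__()
        font = f"<ImageFont at {font_size}pt>"  # stand-in for ImageFont.truetype(...)
        # Values computed inside __init__ can still be recorded on the config:
        self.register_to_config(font=font)


holder = FontHolder()
print(holder.config.font_size)  # 60, captured by the decorator
print(holder.config.font)       # "<ImageFont at 60pt>", added imperatively
```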
@@ -837,9 +844,9 @@ def forward(
                 text = text[:max_chars]
             gly_scale = 2
             if pre_pos[i].mean() != 0:
-                gly_line = self.draw_glyph(self.font, text)
+                gly_line = self.draw_glyph(self.config.font, text)
                 glyphs = self.draw_glyph2(
-                    self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
+                    self.config.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
                 )
                 if revise_pos:
                     resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))
@@ -881,7 +888,7 @@ def forward(
     def arr2tensor(self, arr, bs):
         arr = np.transpose(arr, (2, 0, 1))
         _arr = torch.from_numpy(arr.copy()).float().cpu()
-        if self.use_fp16:
+        if self.config.use_fp16:
            _arr = _arr.half()
         _arr = torch.stack([_arr for _ in range(bs)], dim=0)
         return _arr
@@ -1021,12 +1028,10 @@ def insert_spaces(self, string, nSpace):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]

-    def to(self, *args, **kwargs):
-        self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
-        self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
-        self.text_predictor = self.text_predictor.to(*args, **kwargs)
-        self.device = self.frozen_CLIP_embedder_t3.device
-        return self
+    # def to(self, *args, **kwargs):
+    #     self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
+    #     self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
+    #     return self


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
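Note (not part of the diff): the hand-written to() overrides commented out here and in the other modules lean on the fact that ModelMixin is an nn.Module subclass, so .to() already walks registered child modules and parameters. A minimal reminder of that behaviour:

```python
import torch
import torch.nn as nn


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(4, 4)  # registered child module


outer = Outer().to(torch.float16)
print(outer.inner.weight.dtype)  # torch.float16: the child moved without a custom to()
```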
@@ -1043,20 +1048,17 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")


-class AuxiliaryLatentModule(nn.Module):
+class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
+    @register_to_config
     def __init__(
         self,
-        font_path,
-        vae=None,
+        # font_path,
+        vae,
         device="cpu",
-        use_fp16=False,
     ):
         super().__init__()
-        self.font = ImageFont.truetype(font_path, 60)
-        self.use_fp16 = use_fp16
-        self.device = device
-
-        self.vae = vae.eval() if vae is not None else None
+        # self.font = ImageFont.truetype(font_path, 60)
+        # self.vae = vae.eval() if vae is not None else None

     @torch.no_grad()
     def forward(
@@ -1093,12 +1095,13 @@ def forward(
         # get masked_x
         masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
         masked_img = np.transpose(masked_img, (2, 0, 1))
-        device = next(self.vae.parameters()).device
+        device = next(self.config.vae.parameters()).device
+        dtype = next(self.config.vae.parameters()).dtype
         masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
-        if self.use_fp16:
+        if dtype == torch.float16:
             masked_img = masked_img.half()
-        masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
-        if self.use_fp16:
+        masked_x = (retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor).detach()
+        if dtype == torch.float16:
             masked_x = masked_x.half()
         text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)

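Note (not part of the diff): instead of a stored use_fp16 flag, the forward pass above now infers device and dtype from the VAE's own parameters before encoding the masked image. A simplified sketch of that pattern (it samples the latent distribution directly rather than going through the pipeline's retrieve_latents helper):

```python
import torch
from diffusers import AutoencoderKL


def encode_masked_image(vae: AutoencoderKL, masked_img: torch.Tensor) -> torch.Tensor:
    # Read device/dtype off the VAE weights so the caller never tracks a use_fp16 flag.
    param = next(vae.parameters())
    masked_img = masked_img.to(device=param.device, dtype=param.dtype)
    latents = vae.encode(masked_img[None, ...]).latent_dist.sample()
    return (latents * vae.config.scaling_factor).detach()
```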
@@ -1137,10 +1140,10 @@ def insert_spaces(self, string, nSpace):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]

-    def to(self, *args, **kwargs):
-        self.vae = self.vae.to(*args, **kwargs)
-        self.device = self.vae.device
-        return self
+    # def to(self, *args, **kwargs):
+    #     self.vae = self.vae.to(*args, **kwargs)
+    #     self.device = self.vae.device
+    #     return self


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
@@ -1255,7 +1258,6 @@ class AnyTextPipeline(

     def __init__(
         self,
-        font_path: str,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
tokenizer: CLIPTokenizer,
@@ -1264,18 +1266,25 @@ def __init__(
12641266
scheduler: KarrasDiffusionSchedulers,
12651267
safety_checker: StableDiffusionSafetyChecker,
12661268
feature_extractor: CLIPImageProcessor,
1269+
font_path: str = None,
1270+
text_embedding_module: Optional[TextEmbeddingModule] = None,
1271+
auxiliary_latent_module: Optional[AuxiliaryLatentModule] = None,
12671272
trust_remote_code: bool = False,
1268-
text_embedding_module: TextEmbeddingModule = None,
1269-
auxiliary_latent_module: AuxiliaryLatentModule = None,
12701273
image_encoder: CLIPVisionModelWithProjection = None,
12711274
requires_safety_checker: bool = True,
12721275
):
12731276
super().__init__()
1274-
self.text_embedding_module = TextEmbeddingModule(
1275-
use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
1277+
if font_path is None:
1278+
raise ValueError("font_path is required!")
1279+
1280+
text_embedding_module = TextEmbeddingModule(
1281+
font_path=font_path,
1282+
use_fp16=unet.dtype == torch.float16,
12761283
)
1277-
self.auxiliary_latent_module = AuxiliaryLatentModule(
1278-
vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
1284+
auxiliary_latent_module = AuxiliaryLatentModule(
1285+
# font_path=font_path,
1286+
vae=vae,
1287+
# use_fp16=unet.dtype == torch.float16,
12791288
)
12801289

12811290
if safety_checker is None and requires_safety_checker:
@@ -1307,15 +1316,15 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
-            text_embedding_module=self.text_embedding_module,
-            auxiliary_latent_module=self.auxiliary_latent_module,
+            text_embedding_module=text_embedding_module,
+            auxiliary_latent_module=auxiliary_latent_module,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
         )
-        self.register_to_config(requires_safety_checker=requires_safety_checker, font_path=font_path)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)#, font_path=font_path)

     def modify_prompt(self, prompt):
         prompt = prompt.replace("“", '"')
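Note (not part of the diff): with font_path demoted to an optional keyword that is validated at runtime, and the helper modules built inside __init__ and registered as pipeline components, constructing the pipeline might look like the sketch below. Whether extra keywords such as font_path are forwarded by DiffusionPipeline.from_pretrained depends on how the repo wires this custom pipeline, so treat this as an assumption rather than a documented loading path; the font file name is illustrative.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "tolgacangoz/anytext",
    font_path="arial-unicode-ms.ttf",  # __init__ raises ValueError if this is left out
    torch_dtype=torch.float16,
    variant="fp16",
    trust_remote_code=True,
)
```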
@@ -2331,7 +2340,7 @@ def __call__(
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    control_model_input,
+                    control_model_input.to(self.controlnet.dtype),
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=guided_hint,
