
Commit d5a6e5f

style

1 parent: 13b7ecf

2 files changed (+48, -81 lines)

examples/research_projects/anytext/anytext.py

43 additions, 76 deletions
@@ -41,8 +41,10 @@
 from skimage.transform._geometric import _umeyama as get_sym_mat
 from torch import nn
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask

 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import (
     FromSingleFileMixin,
@@ -52,13 +54,12 @@
 )
 from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.modeling_utils import ModelMixin
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.configuration_utils import register_to_config, ConfigMixin
-from diffusers.models.modeling_utils import ModelMixin
 from diffusers.utils import (
     USE_PEFT_BACKEND,
     deprecate,
@@ -154,21 +155,14 @@ def _is_whitespace(self, char):
 >>> # I chose a font file shared by an HF staff:
 >>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf

->>> # load control net and stable diffusion v1-5
 >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
 ...                                                             variant="fp16",)
 >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
 ...                                          controlnet=anytext_controlnet, torch_dtype=torch.float16,
-...                                          trust_remote_code=True,
+...                                          trust_remote_code=False,  # One needs to give permission to run this pipeline's code
 ...                                          ).to("cuda")

 >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
->>> # uncomment following line if PyTorch>=2.0 is not installed for memory optimization
->>> #pipe.enable_xformers_memory_efficient_attention()
-
->>> # uncomment following line if you want to offload the model to CPU for memory optimization
->>> # also remove the `.to("cuda")` part
->>> #pipe.enable_model_cpu_offload()

 >>> # generate image
 >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
@@ -211,8 +205,8 @@ def __init__(
         embedder,
         placeholder_string="*",
         use_fp16=False,
-        token_dim = 768,
-        get_recog_emb = None,
+        token_dim=768,
+        get_recog_emb=None,
     ):
         super().__init__()
         get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
@@ -227,9 +221,7 @@ def __init__(
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)

-        # self.register_parameter("proj", proj)
         self.placeholder_token = get_token_for_string(placeholder_string)
-        # self.register_config(placeholder_token=placeholder_token)

     @torch.no_grad()
     def encode_text(self, text_info):
@@ -350,12 +342,19 @@ def create_predictor(model_lang="ch", device="cpu", use_fp16=False):
         n_class = 97
     else:
         raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
-    rec_config = dict(
-        in_channels=3,
-        backbone=dict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"),
-        neck=dict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True),
-        head=dict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True),
-    )
+    rec_config = {
+        "in_channels": 3,
+        "backbone": {"type": "MobileNetV1Enhance", "scale": 0.5, "last_conv_stride": [1, 2], "last_pool_type": "avg"},
+        "neck": {
+            "type": "SequenceEncoder",
+            "encoder_type": "svtr",
+            "dims": 64,
+            "depth": 2,
+            "hidden_dims": 120,
+            "use_guide": True,
+        },
+        "head": {"type": "CTCHead", "fc_decay": 0.00001, "out_channels": n_class, "return_feats": True},
+    }

     rec_model = RecModel(rec_config)
     state_dict = torch.load(model_dir, map_location=device)
@@ -521,12 +520,6 @@ def get_ctcloss(self, preds, gt_text, weight):
         return loss


-import torch
-from torch import nn
-from transformers import CLIPTextModel, CLIPTokenizer
-from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
-
-
 class AbstractEncoder(nn.Module):
     def __init__(self):
         super().__init__()
@@ -537,6 +530,7 @@ def encode(self, *args, **kwargs):

 class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+
     @register_to_config
     def __init__(
         self,
@@ -548,11 +542,13 @@ def __init__(
     ):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
-        self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder",
-                                                         torch_dtype=torch.float16 if use_fp16 else torch.float32,
-                                                         variant="fp16" if use_fp16 else None)
-        # self.device = device
-        # self.max_length = max_length
+        self.transformer = CLIPTextModel.from_pretrained(
+            "tolgacangoz/anytext",
+            subfolder="text_encoder",
+            torch_dtype=torch.float16 if use_fp16 else torch.float32,
+            variant="fp16" if use_fp16 else None,
+        )
+
         if freeze:
             self.freeze()

@@ -731,37 +727,28 @@ def split_chunks(self, input_ids, chunk_size=75):
             tokens_list.append(remaining_group_pad)
         return tokens_list

-    # def to(self, *args, **kwargs):
-    #     self.transformer = self.transformer.to(*args, **kwargs)
-    #     self.device = self.transformer.device
-    #     return self
-

 class TextEmbeddingModule(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(self, font_path, use_fp16=False, device="cpu"):
         super().__init__()
         font = ImageFont.truetype(font_path, 60)

-        # self.use_fp16 = use_fp16
-        # self.device = device
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
         self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
         self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
-        args = {"rec_image_shape": "3, 48, 320",
-                "rec_batch_num": 6,
-                "rec_char_dict_path": hf_hub_download(
-                    repo_id="tolgacangoz/anytext",
-                    filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
-                    cache_dir=HF_MODULES_CACHE,
-                ),
-                "use_fp16": use_fp16}
+        args = {
+            "rec_image_shape": "3, 48, 320",
+            "rec_batch_num": 6,
+            "rec_char_dict_path": hf_hub_download(
+                repo_id="tolgacangoz/anytext",
+                filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
+                cache_dir=HF_MODULES_CACHE,
+            ),
+            "use_fp16": use_fp16,
+        }
         self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)

-        # self.register_modules(
-        #     frozen_CLIP_embedder_t3=frozen_CLIP_embedder_t3,
-        #     embedding_manager=embedding_manager,
-        # )
         self.register_to_config(font=font)

     @torch.no_grad()
@@ -873,8 +860,6 @@ def forward(
             text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)]
             text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)]

-        # hint = self.arr2tensor(np_hint, len(prompt))
-
         self.embedding_manager.encode_text(text_info)
         prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)

@@ -1028,11 +1013,6 @@ def insert_spaces(self, string, nSpace):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]

-    # def to(self, *args, **kwargs):
-    #     self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
-    #     self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
-    #     return self
-

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
@@ -1052,13 +1032,10 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        # font_path,
         vae,
         device="cpu",
     ):
         super().__init__()
-        # self.font = ImageFont.truetype(font_path, 60)
-        # self.vae = vae.eval() if vae is not None else None

     @torch.no_grad()
     def forward(
def forward(
@@ -1100,7 +1077,9 @@ def forward(
11001077
masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
11011078
if dtype == torch.float16:
11021079
masked_img = masked_img.half()
1103-
masked_x = (retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor).detach()
1080+
masked_x = (
1081+
retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor
1082+
).detach()
11041083
if dtype == torch.float16:
11051084
masked_x = masked_x.half()
11061085
text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)
@@ -1140,11 +1119,6 @@ def insert_spaces(self, string, nSpace):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]

-    # def to(self, *args, **kwargs):
-    #     self.vae = self.vae.to(*args, **kwargs)
-    #     self.device = self.vae.device
-    #     return self
-

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
@@ -1277,15 +1251,8 @@ def __init__(
         if font_path is None:
             raise ValueError("font_path is required!")

-        text_embedding_module = TextEmbeddingModule(
-            font_path=font_path,
-            use_fp16=unet.dtype == torch.float16,
-        )
-        auxiliary_latent_module = AuxiliaryLatentModule(
-            # font_path=font_path,
-            vae=vae,
-            # use_fp16=unet.dtype == torch.float16,
-        )
+        text_embedding_module = TextEmbeddingModule(font_path=font_path, use_fp16=unet.dtype == torch.float16)
+        auxiliary_latent_module = AuxiliaryLatentModule(vae=vae)

         if safety_checker is None and requires_safety_checker:
             logger.warning(
@@ -1324,7 +1291,7 @@ def __init__(
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
         )
-        self.register_to_config(requires_safety_checker=requires_safety_checker)#, font_path=font_path)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)

     def modify_prompt(self, prompt):
         prompt = prompt.replace("“", '"')
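
For readers skimming the diff: several of these hunks revolve around the diffusers ConfigMixin/register_to_config pattern (the reordered imports, the @register_to_config decorators, and the self.config.vae access in AuxiliaryLatentModule.forward). Decorating __init__ with @register_to_config records the init arguments into self.config, which is what makes them retrievable later. A minimal sketch of the pattern — not part of this commit; TinyModule and its arguments are invented for illustration:

from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class TinyModule(ModelMixin, ConfigMixin):
    config_name = "config.json"  # ConfigMixin uses this filename when saving the config

    @register_to_config  # records __init__ kwargs into self.config
    def __init__(self, hidden_dim=64, use_fp16=False):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)


module = TinyModule(hidden_dim=32)
print(module.config.hidden_dim)  # 32 — the same mechanism behind self.config.vae above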

examples/research_projects/anytext/ocr_recog/RecModel.py

5 additions, 5 deletions
@@ -16,15 +16,15 @@ def __init__(self, config):
         assert "in_channels" in config, "in_channels must in model config"
         backbone_type = config["backbone"].pop("type")
         assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
-        self.backbone = backbone_dict[backbone_type](config['in_channels'], **config['backbone'])
+        self.backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"])

-        neck_type = config['neck'].pop("type")
+        neck_type = config["neck"].pop("type")
         assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
-        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config['neck'])
+        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config["neck"])

-        head_type = config['head'].pop("type")
+        head_type = config["head"].pop("type")
         assert head_type in head_dict, f"head.type must in {head_dict}"
-        self.head = head_dict[head_type](self.neck.out_channels, **config['head'])
+        self.head = head_dict[head_type](self.neck.out_channels, **config["head"])

         self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"
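
For context on how this config is consumed: the rec_config dict built in anytext.py's create_predictor feeds straight into this constructor, which dispatches each stage through a registry — pop the "type" key, look the class up in a module-level dict, and pass the remaining keys through as keyword arguments. A self-contained sketch of that pattern, with a hypothetical DummyBackbone standing in for the real ocr_recog classes:

class DummyBackbone:
    # stand-in for e.g. MobileNetV1Enhance; only the interface matters here
    def __init__(self, in_channels, scale=0.5):
        self.out_channels = int(512 * scale)  # the next stage reads this attribute


backbone_dict = {"MobileNetV1Enhance": DummyBackbone}  # name -> class registry

config = {"in_channels": 3, "backbone": {"type": "MobileNetV1Enhance", "scale": 0.5}}
backbone_type = config["backbone"].pop("type")  # pop so only ctor kwargs remain
backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"])
print(backbone.out_channels)  # 256

Note that pop("type") mutates the passed-in config; create_predictor builds a fresh rec_config on each call, so the mutation is harmless there.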
