 from skimage.transform._geometric import _umeyama as get_sym_mat
 from torch import nn
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask

 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import (
     FromSingleFileMixin,
 )
 from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.modeling_utils import ModelMixin
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.configuration_utils import register_to_config, ConfigMixin
-from diffusers.models.modeling_utils import ModelMixin
 from diffusers.utils import (
     USE_PEFT_BACKEND,
     deprecate,
@@ -154,21 +155,14 @@ def _is_whitespace(self, char):
         >>> # I chose a font file shared by an HF staff member:
         >>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf

-        >>> # load control net and stable diffusion v1-5
         >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
         ...                                                             variant="fp16",)
         >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
         ...                                          controlnet=anytext_controlnet, torch_dtype=torch.float16,
-        ...                                          trust_remote_code=True,
+        ...                                          trust_remote_code=False,  # One needs to give permission to run this pipeline's code
         ...                                          ).to("cuda")

         >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
-        >>> # uncomment following line if PyTorch>=2.0 is not installed for memory optimization
-        >>> #pipe.enable_xformers_memory_efficient_attention()
-
-        >>> # uncomment following line if you want to offload the model to CPU for memory optimization
-        >>> # also remove the `.to("cuda")` part
-        >>> #pipe.enable_model_cpu_offload()

         >>> # generate image
         >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
@@ -211,8 +205,8 @@ def __init__(
         embedder,
         placeholder_string="*",
         use_fp16=False,
-        token_dim = 768,
-        get_recog_emb = None,
+        token_dim=768,
+        get_recog_emb=None,
     ):
         super().__init__()
         get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
@@ -227,9 +221,7 @@ def __init__(
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)

-        # self.register_parameter("proj", proj)
         self.placeholder_token = get_token_for_string(placeholder_string)
-        # self.register_config(placeholder_token=placeholder_token)

     @torch.no_grad()
     def encode_text(self, text_info):
@@ -350,12 +342,19 @@ def create_predictor(model_lang="ch", device="cpu", use_fp16=False):
         n_class = 97
     else:
         raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
-    rec_config = dict(
-        in_channels=3,
-        backbone=dict(type="MobileNetV1Enhance", scale=0.5, last_conv_stride=[1, 2], last_pool_type="avg"),
-        neck=dict(type="SequenceEncoder", encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True),
-        head=dict(type="CTCHead", fc_decay=0.00001, out_channels=n_class, return_feats=True),
-    )
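+    # Recognizer architecture: MobileNetV1Enhance backbone, SVTR sequence-encoder neck, and a CTC decoding head.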
+    rec_config = {
+        "in_channels": 3,
+        "backbone": {"type": "MobileNetV1Enhance", "scale": 0.5, "last_conv_stride": [1, 2], "last_pool_type": "avg"},
+        "neck": {
+            "type": "SequenceEncoder",
+            "encoder_type": "svtr",
+            "dims": 64,
+            "depth": 2,
+            "hidden_dims": 120,
+            "use_guide": True,
+        },
+        "head": {"type": "CTCHead", "fc_decay": 0.00001, "out_channels": n_class, "return_feats": True},
+    }

     rec_model = RecModel(rec_config)
     state_dict = torch.load(model_dir, map_location=device)
@@ -521,12 +520,6 @@ def get_ctcloss(self, preds, gt_text, weight):
         return loss


-import torch
-from torch import nn
-from transformers import CLIPTextModel, CLIPTokenizer
-from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
-
-
 class AbstractEncoder(nn.Module):
     def __init__(self):
         super().__init__()
@@ -537,6 +530,7 @@ def encode(self, *args, **kwargs):

 class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+
     @register_to_config
     def __init__(
         self,
@@ -548,11 +542,13 @@ def __init__(
     ):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
-        self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder",
-                                                         torch_dtype=torch.float16 if use_fp16 else torch.float32,
-                                                         variant="fp16" if use_fp16 else None)
-        # self.device = device
-        # self.max_length = max_length
+        self.transformer = CLIPTextModel.from_pretrained(
+            "tolgacangoz/anytext",
+            subfolder="text_encoder",
+            torch_dtype=torch.float16 if use_fp16 else torch.float32,
+            variant="fp16" if use_fp16 else None,
+        )
+
         if freeze:
             self.freeze()

@@ -731,37 +727,28 @@ def split_chunks(self, input_ids, chunk_size=75):
             tokens_list.append(remaining_group_pad)
         return tokens_list

-    # def to(self, *args, **kwargs):
-    #     self.transformer = self.transformer.to(*args, **kwargs)
-    #     self.device = self.transformer.device
-    #     return self
-

 class TextEmbeddingModule(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(self, font_path, use_fp16=False, device="cpu"):
         super().__init__()
         font = ImageFont.truetype(font_path, 60)

-        # self.use_fp16 = use_fp16
-        # self.device = device
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
         self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
         self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
-        args = {"rec_image_shape": "3, 48, 320",
-                "rec_batch_num": 6,
-                "rec_char_dict_path": hf_hub_download(
-                    repo_id="tolgacangoz/anytext",
-                    filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
-                    cache_dir=HF_MODULES_CACHE,
-                ),
-                "use_fp16": use_fp16}
+        args = {
+            "rec_image_shape": "3, 48, 320",
+            "rec_batch_num": 6,
+            "rec_char_dict_path": hf_hub_download(
+                repo_id="tolgacangoz/anytext",
+                filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
+                cache_dir=HF_MODULES_CACHE,
+            ),
+            "use_fp16": use_fp16,
+        }
         self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)

-        # self.register_modules(
-        #     frozen_CLIP_embedder_t3=frozen_CLIP_embedder_t3,
-        #     embedding_manager=embedding_manager,
-        # )
         self.register_to_config(font=font)

     @torch.no_grad()
@@ -873,8 +860,6 @@ def forward(
             text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)]
             text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)]

-        # hint = self.arr2tensor(np_hint, len(prompt))
-
         self.embedding_manager.encode_text(text_info)
         prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)

@@ -1028,11 +1013,6 @@ def insert_spaces(self, string, nSpace):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]

-    # def to(self, *args, **kwargs):
-    #     self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
-    #     self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
-    #     return self
-

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
@@ -1052,13 +1032,10 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        # font_path,
         vae,
         device="cpu",
     ):
         super().__init__()
-        # self.font = ImageFont.truetype(font_path, 60)
-        # self.vae = vae.eval() if vae is not None else None

     @torch.no_grad()
     def forward(
@@ -1100,7 +1077,9 @@ def forward(
         masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
         if dtype == torch.float16:
             masked_img = masked_img.half()
-        masked_x = (retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor).detach()
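+        # Encode the masked image into VAE latent space and apply the VAE scaling factor.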
+        masked_x = (
+            retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor
+        ).detach()
         if dtype == torch.float16:
             masked_x = masked_x.half()
         text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)
@@ -1140,11 +1119,6 @@ def insert_spaces(self, string, nSpace):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]

-    # def to(self, *args, **kwargs):
-    #     self.vae = self.vae.to(*args, **kwargs)
-    #     self.device = self.vae.device
-    #     return self
-

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
@@ -1277,15 +1251,8 @@ def __init__(
         if font_path is None:
             raise ValueError("font_path is required!")

-        text_embedding_module = TextEmbeddingModule(
-            font_path=font_path,
-            use_fp16=unet.dtype == torch.float16,
-        )
-        auxiliary_latent_module = AuxiliaryLatentModule(
-            # font_path=font_path,
-            vae=vae,
-            # use_fp16=unet.dtype == torch.float16,
-        )
+        text_embedding_module = TextEmbeddingModule(font_path=font_path, use_fp16=unet.dtype == torch.float16)
+        auxiliary_latent_module = AuxiliaryLatentModule(vae=vae)

         if safety_checker is None and requires_safety_checker:
             logger.warning(
@@ -1324,7 +1291,7 @@ def __init__(
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
         )
-        self.register_to_config(requires_safety_checker=requires_safety_checker)#, font_path=font_path)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)

     def modify_prompt(self, prompt):
         prompt = prompt.replace("“", '"')