@@ -34,7 +34,6 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from easydict import EasyDict as edict
 from huggingface_hub import hf_hub_download
 from ocr_recog.RecModel import RecModel
 from PIL import Image, ImageDraw, ImageFont
@@ -58,6 +57,8 @@
 from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.configuration_utils import register_to_config, ConfigMixin
+from diffusers.models.modeling_utils import ModelMixin
 from diffusers.utils import (
     USE_PEFT_BACKEND,
     deprecate,
@@ -203,18 +204,18 @@ def get_recog_emb(encoder, img_list):
     return preds_neck


-class EmbeddingManager(nn.Module):
+class EmbeddingManager(ModelMixin, ConfigMixin):
+    @register_to_config
     def __init__(
         self,
         embedder,
         placeholder_string="*",
         use_fp16=False,
+        token_dim=768,
+        get_recog_emb=None,
     ):
         super().__init__()
         get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
-        token_dim = 768
-        self.get_recog_emb = None
-        self.token_dim = token_dim

         self.proj = nn.Linear(40 * 64, token_dim)
         proj_dir = hf_hub_download(
@@ -226,12 +227,14 @@ def __init__(
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)

+        # self.register_parameter("proj", proj)
         self.placeholder_token = get_token_for_string(placeholder_string)
+        # self.register_config(placeholder_token=placeholder_token)

     @torch.no_grad()
     def encode_text(self, text_info):
-        if self.get_recog_emb is None:
-            self.get_recog_emb = partial(get_recog_emb, self.recog)
+        if self.config.get_recog_emb is None:
+            self.config.get_recog_emb = partial(get_recog_emb, self.recog)

         gline_list = []
         for i in range(len(text_info["n_lines"])):  # sample index in a batch
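
Note on the `ModelMixin`/`ConfigMixin` switch above: `@register_to_config` records the `__init__` keyword arguments (here `token_dim` and `get_recog_emb`) on `self.config`. A minimal sketch of that mechanism, assuming only that diffusers is installed; `ToyModule` is a hypothetical stand-in, not part of this PR:

```python
# Sketch: @register_to_config captures init kwargs onto .config.
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

class ToyModule(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, token_dim=768, use_fp16=False):
        super().__init__()
        self.proj = nn.Linear(40 * 64, token_dim)

module = ToyModule()
print(module.config.token_dim)  # 768 -- init kwargs land on .config
```

One caveat: diffusers configs are frozen mappings, so assigning back into `self.config` at runtime, as `encode_text` does with `get_recog_emb`, may raise; reading from the config, as the sketch does, is the supported path.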
@@ -240,7 +243,7 @@ def encode_text(self, text_info):
                 gline_list += [text_info["gly_line"][j][i : i + 1]]

         if len(gline_list) > 0:
-            recog_emb = self.get_recog_emb(gline_list)
+            recog_emb = self.config.get_recog_emb(gline_list)
             enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype))

         self.text_embs_all = []
@@ -332,13 +335,12 @@ def crop_image(src_img, mask):
     return result


-def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
-    if model_dir is None or not os.path.exists(model_dir):
-        model_dir = hf_hub_download(
-            repo_id="tolgacangoz/anytext",
-            filename="text_embedding_module/OCR/ppv3_rec.pth",
-            cache_dir=HF_MODULES_CACHE,
-        )
+def create_predictor(model_lang="ch", device="cpu", use_fp16=False):
+    model_dir = hf_hub_download(
+        repo_id="tolgacangoz/anytext",
+        filename="text_embedding_module/OCR/ppv3_rec.pth",
+        cache_dir=HF_MODULES_CACHE,
+    )
     if not os.path.exists(model_dir):
         raise ValueError("not find model file path {}".format(model_dir))

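
For context on the simplification above: `hf_hub_download` resolves to a path in the local cache and only downloads when the file is missing, so the old `os.path.exists` guard around the call is redundant. A small sketch using the same repo and filename as the diff:

```python
# Sketch: hf_hub_download returns a cached local path, downloading only on a miss.
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="tolgacangoz/anytext",
    filename="text_embedding_module/OCR/ppv3_rec.pth",
)
print(path)  # local path inside the Hugging Face cache
```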
@@ -533,24 +535,24 @@ def encode(self, *args, **kwargs):
         raise NotImplementedError


-class FrozenCLIPEmbedderT3(AbstractEncoder):
+class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-
+    @register_to_config
     def __init__(
         self,
-        version="openai/clip-vit-large-patch14",
         device="cpu",
         max_length=77,
         freeze=True,
         use_fp16=False,
+        variant: Optional[str] = None,
     ):
         super().__init__()
-        self.tokenizer = CLIPTokenizer.from_pretrained(version)
-        self.transformer = CLIPTextModel.from_pretrained(
-            version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32
-        ).to(device)
-        self.device = device
-        self.max_length = max_length
+        self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
+        self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder",
+                                                         torch_dtype=torch.float16 if use_fp16 else torch.float32,
+                                                         variant="fp16" if use_fp16 else None)
+        # self.device = device
+        # self.max_length = max_length
         if freeze:
             self.freeze()

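
The hard-coded `version` checkpoint is replaced by subfolder loading from the pipeline repo, with the usual `variant` convention picking fp16 weights. A sketch of the loading pattern, assuming the repo follows the standard diffusers layout with a `model.fp16.safetensors` variant:

```python
# Sketch: load tokenizer and text encoder from subfolders of one repo.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    "tolgacangoz/anytext",
    subfolder="text_encoder",
    torch_dtype=torch.float16,
    variant="fp16",  # resolves model.fp16.safetensors when present
)
```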
@@ -686,7 +688,7 @@ def forward(self, text, **kwargs):
         batch_encoding = self.tokenizer(
             text,
             truncation=False,
-            max_length=self.max_length,
+            max_length=self.config.max_length,
             return_length=True,
             return_overflowing_tokens=False,
             padding="longest",
@@ -729,34 +731,39 @@ def split_chunks(self, input_ids, chunk_size=75):
             tokens_list.append(remaining_group_pad)
         return tokens_list

-    def to(self, *args, **kwargs):
-        self.transformer = self.transformer.to(*args, **kwargs)
-        self.device = self.transformer.device
-        return self
+    # def to(self, *args, **kwargs):
+    #     self.transformer = self.transformer.to(*args, **kwargs)
+    #     self.device = self.transformer.device
+    #     return self


-class TextEmbeddingModule(nn.Module):
+class TextEmbeddingModule(ModelMixin, ConfigMixin):
+    @register_to_config
     def __init__(self, font_path, use_fp16=False, device="cpu"):
         super().__init__()
-        self.font = ImageFont.truetype(font_path, 60)
-        self.use_fp16 = use_fp16
-        self.device = device
+        font = ImageFont.truetype(font_path, 60)
+
+        # self.use_fp16 = use_fp16
+        # self.device = device
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
         self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
-        rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
-        self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
-        args = {}
-        args["rec_image_shape"] = "3, 48, 320"
-        args["rec_batch_num"] = 6
-        args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt"
-        args["rec_char_dict_path"] = hf_hub_download(
-            repo_id="tolgacangoz/anytext",
-            filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
-            cache_dir=HF_MODULES_CACHE,
-        )
-        args["use_fp16"] = use_fp16
+        self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
+        args = {"rec_image_shape": "3, 48, 320",
+                "rec_batch_num": 6,
+                "rec_char_dict_path": hf_hub_download(
+                    repo_id="tolgacangoz/anytext",
+                    filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
+                    cache_dir=HF_MODULES_CACHE,
+                ),
+                "use_fp16": use_fp16}
         self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)

+        # self.register_modules(
+        #     frozen_CLIP_embedder_t3=frozen_CLIP_embedder_t3,
+        #     embedding_manager=embedding_manager,
+        # )
+        self.register_to_config(font=font)
+
     @torch.no_grad()
     def forward(
         self,
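
Commenting out the hand-written `to()` overrides leans on `ModelMixin`, whose `.to()` cascades to registered child modules and which exposes `.device`/`.dtype` derived from parameters. A self-contained sketch with hypothetical `Parent`/`Child` classes (not from this PR):

```python
# Sketch: ModelMixin tracks device/dtype and .to() reaches child modules,
# which is what makes the removed to() overrides redundant.
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

class Child(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

class Parent(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self):
        super().__init__()
        self.child = Child()

parent = Parent().to(torch.float16)
print(parent.child.device, parent.child.dtype)  # cpu torch.float16
```

Registering the PIL `font` object via `register_to_config` keeps it reachable as `self.config.font`, though a non-primitive value like this presumably would not survive `save_pretrained` serialization.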
@@ -837,9 +844,9 @@ def forward(
                 text = text[:max_chars]
             gly_scale = 2
             if pre_pos[i].mean() != 0:
-                gly_line = self.draw_glyph(self.font, text)
+                gly_line = self.draw_glyph(self.config.font, text)
                 glyphs = self.draw_glyph2(
-                    self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
+                    self.config.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
                 )
                 if revise_pos:
                     resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))
@@ -881,7 +888,7 @@ def forward(
     def arr2tensor(self, arr, bs):
         arr = np.transpose(arr, (2, 0, 1))
         _arr = torch.from_numpy(arr.copy()).float().cpu()
-        if self.use_fp16:
+        if self.config.use_fp16:
             _arr = _arr.half()
         _arr = torch.stack([_arr for _ in range(bs)], dim=0)
         return _arr
@@ -1021,12 +1028,10 @@ def insert_spaces(self, string, nSpace):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]

-    def to(self, *args, **kwargs):
-        self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
-        self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
-        self.text_predictor = self.text_predictor.to(*args, **kwargs)
-        self.device = self.frozen_CLIP_embedder_t3.device
-        return self
+    # def to(self, *args, **kwargs):
+    #     self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
+    #     self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
+    #     return self


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
@@ -1043,20 +1048,17 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")


-class AuxiliaryLatentModule(nn.Module):
+class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
+    @register_to_config
     def __init__(
         self,
-        font_path,
-        vae=None,
+        # font_path,
+        vae,
         device="cpu",
-        use_fp16=False,
     ):
         super().__init__()
-        self.font = ImageFont.truetype(font_path, 60)
-        self.use_fp16 = use_fp16
-        self.device = device
-
-        self.vae = vae.eval() if vae is not None else None
+        # self.font = ImageFont.truetype(font_path, 60)
+        # self.vae = vae.eval() if vae is not None else None

     @torch.no_grad()
     def forward(
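
Since `vae` is only captured by `@register_to_config` and never assigned as an attribute, it is not a registered submodule of `AuxiliaryLatentModule`: `.to()` will not touch it, which is presumably why `forward()` below re-derives device and dtype from the VAE's own parameters. A sketch of that distinction, with a hypothetical `Holder` class:

```python
# Sketch: a module kept only in the config is not an nn.Module child,
# so .to() leaves its parameters untouched.
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

class Holder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, vae=None):
        super().__init__()

holder = Holder(vae=nn.Linear(2, 2))
holder.to(torch.float16)
print(next(holder.config.vae.parameters()).dtype)  # still torch.float32
```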
@@ -1093,12 +1095,13 @@ def forward(
         # get masked_x
         masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
         masked_img = np.transpose(masked_img, (2, 0, 1))
-        device = next(self.vae.parameters()).device
+        device = next(self.config.vae.parameters()).device
+        dtype = next(self.config.vae.parameters()).dtype
         masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
-        if self.use_fp16:
+        if dtype == torch.float16:
             masked_img = masked_img.half()
-        masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
-        if self.use_fp16:
+        masked_x = (retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor).detach()
+        if dtype == torch.float16:
             masked_x = masked_x.half()
         text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)
@@ -1137,10 +1140,10 @@ def insert_spaces(self, string, nSpace):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]

-    def to(self, *args, **kwargs):
-        self.vae = self.vae.to(*args, **kwargs)
-        self.device = self.vae.device
-        return self
+    # def to(self, *args, **kwargs):
+    #     self.vae = self.vae.to(*args, **kwargs)
+    #     self.device = self.vae.device
+    #     return self


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
@@ -1255,7 +1258,6 @@ class AnyTextPipeline(

     def __init__(
         self,
-        font_path: str,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
@@ -1264,18 +1266,25 @@ def __init__(
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        font_path: str = None,
+        text_embedding_module: Optional[TextEmbeddingModule] = None,
+        auxiliary_latent_module: Optional[AuxiliaryLatentModule] = None,
         trust_remote_code: bool = False,
-        text_embedding_module: TextEmbeddingModule = None,
-        auxiliary_latent_module: AuxiliaryLatentModule = None,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
-        self.text_embedding_module = TextEmbeddingModule(
-            use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
+        if font_path is None:
+            raise ValueError("font_path is required!")
+
+        text_embedding_module = TextEmbeddingModule(
+            font_path=font_path,
+            use_fp16=unet.dtype == torch.float16,
         )
-        self.auxiliary_latent_module = AuxiliaryLatentModule(
-            vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
+        auxiliary_latent_module = AuxiliaryLatentModule(
+            # font_path=font_path,
+            vae=vae,
+            # use_fp16=unet.dtype == torch.float16,
         )

         if safety_checker is None and requires_safety_checker:
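
With the submodules now constructed inside `__init__`, loading reduces to passing `font_path` alongside the usual kwargs. A hypothetical usage sketch; the repo id matches the diff, while the font filename is an illustrative assumption:

```python
# Hypothetical usage sketch; the font file name is an assumption.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "tolgacangoz/anytext",
    font_path="Arial_Unicode.ttf",  # any TTF covering the target script
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
```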
@@ -1307,15 +1316,15 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
-            text_embedding_module=self.text_embedding_module,
-            auxiliary_latent_module=self.auxiliary_latent_module,
+            text_embedding_module=text_embedding_module,
+            auxiliary_latent_module=auxiliary_latent_module,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
         )
-        self.register_to_config(requires_safety_checker=requires_safety_checker, font_path=font_path)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)  # , font_path=font_path)

     def modify_prompt(self, prompt):
         prompt = prompt.replace("“", '"')
@@ -2331,7 +2340,7 @@ def __call__(
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]

                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    control_model_input,
+                    control_model_input.to(self.controlnet.dtype),
                     t,
                     encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=guided_hint,
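
The added cast guards against the usual mixed-precision mismatch, fp32 latents meeting an fp16 ControlNet. A minimal sketch of the failure mode and fix, with stand-in tensors:

```python
# Sketch: align the input dtype with the ControlNet weights before the call;
# a mismatch surfaces as "expected Half but found Float" at the first conv.
import torch

control_model_input = torch.randn(2, 4, 64, 64)  # fp32 latents
controlnet_dtype = torch.float16                 # dtype of the ControlNet weights
control_model_input = control_model_input.to(controlnet_dtype)
print(control_model_input.dtype)  # torch.float16
```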