 from PIL import ImageFont
 from recognizer import TextRecognizer, create_predictor
 from torch import nn
+from torch.nn import functional as F
+import numpy as np
+import cv2
 
 from diffusers.utils import (
     logging,
@@ -25,7 +28,7 @@ def __init__(self, use_fp16):
         super().__init__()
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         # TODO: Learn if the recommended font file is free to use
-        self.font = ImageFont.truetype("/home/x/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60)
+        self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60)
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device)
         self.embedding_manager_config = {
             "valid": True,
@@ -39,12 +42,12 @@ def __init__(self, use_fp16):
         # TODO: Understand the reason of param.requires_grad = True
         for param in self.embedding_manager.embedding_parameters():
             param.requires_grad = True
-        rec_model_dir = "/home/x/Documents/gits/AnyText/ocr_weights/ppv3_rec.pth"
+        rec_model_dir = "/home/cosmos/Documents/gits/AnyText/ocr_weights/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir).eval()
         args = {}
         args["rec_image_shape"] = "3, 48, 320"
         args["rec_batch_num"] = 6
-        args["rec_char_dict_path"] = "/home/x/Documents/gits/AnyText/ocr_weights/ppocr_keys_v1.txt"
+        args["rec_char_dict_path"] = "/home/cosmos/Documents/gits/AnyText/ocr_weights/ppocr_keys_v1.txt"
         args["use_fp16"] = use_fp16
         self.cn_recognizer = TextRecognizer(args, self.text_predictor)
         for param in self.text_predictor.parameters():
@@ -55,11 +58,140 @@ def __init__(self, use_fp16):
     def forward(
         self,
         prompt,
-        text_info,
-        negative_prompt=None,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        texts,
+        negative_prompt,
+        num_images_per_prompt,
+        mode,
+        draw_pos,
+        ori_image,
+        max_chars=77,
+        revise_pos=False,
+        sort_priority=False,
+        h=512,
+        w=512,
     ):
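+        """
+        Prepare AnyText conditioning: render a glyph image for every text line,
+        build per-line position masks, and encode the masked reference image.
+        """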
+        if prompt is None and texts is None:
+            raise ValueError("Prompt or texts must be provided!")
+        n_lines = len(texts)
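+        # "generate" starts from an empty gray canvas; "edit" in-paints text into ori_image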
+        if mode == "generate":
+            edit_image = np.ones((h, w, 3)) * 127.5  # empty mask image
+        elif mode == "edit":
+            if draw_pos is None or ori_image is None:
+                raise ValueError("Reference image and position image are needed for text editing!")
+            if isinstance(ori_image, str):
+                img_path = ori_image
+                ori_image = cv2.imread(img_path)
+                if ori_image is None:
+                    raise ValueError(f"Can't read ori_image from {img_path}!")
+                ori_image = ori_image[..., ::-1]  # BGR -> RGB
+            elif isinstance(ori_image, torch.Tensor):
+                ori_image = ori_image.cpu().numpy()
+            elif not isinstance(ori_image, np.ndarray):
+                raise ValueError(f"Unknown format of ori_image: {type(ori_image)}")
+            edit_image = ori_image.clip(1, 255)  # keep pixels >= 1: zero is reserved for masked regions
+            edit_image = self.check_channels(edit_image)
+            edit_image = self.resize_image(
+                edit_image, max_length=768
+            )  # make w, h multiples of 64; downscale if either exceeds max_length
+            h, w = edit_image.shape[:2]  # update h, w from the input reference image
+        # preprocess pos_imgs (if numpy, make sure it's white positions on a black background)
+        if draw_pos is None:
+            pos_imgs = np.zeros((h, w, 1))
+        elif isinstance(draw_pos, str):
+            pos_path = draw_pos
+            draw_pos = cv2.imread(pos_path)
+            if draw_pos is None:
+                raise ValueError(f"Can't read draw_pos image from {pos_path}!")
+            draw_pos = draw_pos[..., ::-1]  # BGR -> RGB
+            pos_imgs = 255 - draw_pos  # invert: files store black positions on white
+        elif isinstance(draw_pos, torch.Tensor):
+            pos_imgs = draw_pos.cpu().numpy()
+        elif isinstance(draw_pos, np.ndarray):
+            pos_imgs = draw_pos
+        else:
+            raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}")
+        if mode == "edit":
+            pos_imgs = cv2.resize(pos_imgs, (w, h))
+        pos_imgs = pos_imgs[..., 0:1]
+        pos_imgs = cv2.convertScaleAbs(pos_imgs)
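+        # binarize: only near-white pixels (> 254) count as valid position markers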
+        _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
+        # split the combined mask into one mask per text line
+        pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
+        if len(pos_imgs) == 0:
+            pos_imgs = [np.zeros((h, w, 1))]
+        if len(pos_imgs) < n_lines:
+            if n_lines == 1 and texts[0] == " ":
+                pass  # text-to-image without text
+            else:
+                raise ValueError(
+                    f"Found {len(pos_imgs)} positions, but {n_lines} are needed from the prompt; check and try again!"
+                )
+        elif len(pos_imgs) > n_lines:
+            str_warning = f"Found {len(pos_imgs)} positions, more than the {n_lines} needed from the prompt."
+            logger.warning(str_warning)
+        # get pre_pos and poly_list, which AnyText needs for each text line
+        pre_pos = []
+        poly_list = []
+        for input_pos in pos_imgs:
+            if input_pos.mean() != 0:
+                input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos
+                poly, pos_img = self.find_polygon(input_pos)
+                pre_pos += [pos_img / 255.0]
+                poly_list += [poly]
+            else:
+                pre_pos += [np.zeros((h, w, 1))]
+                poly_list += [None]
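+        # np_hint: union of all position masks; later zeroes out the text regions in masked_img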
+        np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
+        # prepare the text_info dict consumed by the embedding manager
+        text_info = {}
+        text_info["glyphs"] = []
+        text_info["gly_line"] = []
+        text_info["positions"] = []
+        text_info["n_lines"] = [len(texts)] * num_images_per_prompt
+        for i in range(len(texts)):
+            text = texts[i]
+            if len(text) > max_chars:
+                str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...'
+                logger.warning(str_warning)
+                text = text[:max_chars]
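+            # glyphs are rendered at gly_scale times the output resolution for sharper strokes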
+            gly_scale = 2
+            if pre_pos[i].mean() != 0:
+                gly_line = self.draw_glyph(self.font, text)
+                glyphs = self.draw_glyph2(
+                    self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
+                )
+                if revise_pos:
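+                    # close gaps in the rendered glyph mask, then tighten the position to its min-area rect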
+                    resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))
+                    new_pos = cv2.morphologyEx(
+                        (resize_gly * 255).astype(np.uint8),
+                        cv2.MORPH_CLOSE,
+                        kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8),
+                        iterations=1,
+                    )
+                    new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
+                    contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+                    if len(contours) != 1:
+                        str_warning = f"Failed to revise position {i} to a bounding rect; keeping the original position..."
+                        logger.warning(str_warning)
+                    else:
+                        rect = cv2.minAreaRect(contours[0])
+                        poly = np.int0(cv2.boxPoints(rect))
+                        pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
+            else:
+                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
+                gly_line = np.zeros((80, 512, 1))
+            pos = pre_pos[i]
+            text_info["glyphs"] += [self.arr2tensor(glyphs, len(prompt))]
+            text_info["gly_line"] += [self.arr2tensor(gly_line, len(prompt))]
+            text_info["positions"] += [self.arr2tensor(pos, len(prompt))]
+        # get masked_x: VAE latent of the edit image with all text regions blanked out
+        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
+        masked_img = np.transpose(masked_img, (2, 0, 1))
+        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
+        if self.use_fp16:
+            masked_img = masked_img.half()
+        masked_x = self.encode_first_stage(masked_img[None, ...]).detach()
+        if self.use_fp16:
+            masked_x = masked_x.half()
+        text_info["masked_x"] = torch.cat([masked_x for _ in range(len(prompt))], dim=0)
+        # hint = self.arr2tensor(np_hint, len(prompt))
+
         self.embedding_manager.encode_text(text_info)
         prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)
 