@@ -196,10 +196,10 @@ def forward(
196196 np_hint = np .sum (pre_pos , axis = 0 ).clip (0 , 1 )
197197 # prepare info dict
198198 info = {}
199- info [' glyphs' ] = []
200- info [' gly_line' ] = []
201- info [' positions' ] = []
202- info [' n_lines' ] = [len (texts )]* len (prompt )
199+ info [" glyphs" ] = []
200+ info [" gly_line" ] = []
201+ info [" positions" ] = []
202+ info [" n_lines" ] = [len (texts )] * len (prompt )
203203 for i in range (len (texts )):
204204 text = texts [i ]
205205 if len (text ) > max_chars :
@@ -209,40 +209,47 @@ def forward(
209209 gly_scale = 2
210210 if pre_pos [i ].mean () != 0 :
211211 gly_line = self .draw_glyph (self .font , text )
212- glyphs = self .draw_glyph2 (self .font , text , poly_list [i ], scale = gly_scale , width = w , height = h , add_space = False )
212+ glyphs = self .draw_glyph2 (
213+ self .font , text , poly_list [i ], scale = gly_scale , width = w , height = h , add_space = False
214+ )
213215 if revise_pos :
214216 resize_gly = cv2 .resize (glyphs , (pre_pos [i ].shape [1 ], pre_pos [i ].shape [0 ]))
215- new_pos = cv2 .morphologyEx ((resize_gly * 255 ).astype (np .uint8 ), cv2 .MORPH_CLOSE , kernel = np .ones ((resize_gly .shape [0 ]// 10 , resize_gly .shape [1 ]// 10 ), dtype = np .uint8 ), iterations = 1 )
217+ new_pos = cv2 .morphologyEx (
218+ (resize_gly * 255 ).astype (np .uint8 ),
219+ cv2 .MORPH_CLOSE ,
220+ kernel = np .ones ((resize_gly .shape [0 ] // 10 , resize_gly .shape [1 ] // 10 ), dtype = np .uint8 ),
221+ iterations = 1 ,
222+ )
216223 new_pos = new_pos [..., np .newaxis ] if len (new_pos .shape ) == 2 else new_pos
217224 contours , _ = cv2 .findContours (new_pos , cv2 .RETR_EXTERNAL , cv2 .CHAIN_APPROX_NONE )
218225 if len (contours ) != 1 :
219- str_warning = f' Fail to revise position { i } to bounding rect, remain position unchanged...'
226+ str_warning = f" Fail to revise position { i } to bounding rect, remain position unchanged..."
220227 logger .warning (str_warning )
221228 else :
222229 rect = cv2 .minAreaRect (contours [0 ])
223230 poly = np .int0 (cv2 .boxPoints (rect ))
224- pre_pos [i ] = cv2 .drawContours (new_pos , [poly ], - 1 , 255 , - 1 ) / 255.
231+ pre_pos [i ] = cv2 .drawContours (new_pos , [poly ], - 1 , 255 , - 1 ) / 255.0
225232 else :
226- glyphs = np .zeros ((h * gly_scale , w * gly_scale , 1 ))
233+ glyphs = np .zeros ((h * gly_scale , w * gly_scale , 1 ))
227234 gly_line = np .zeros ((80 , 512 , 1 ))
228235 pos = pre_pos [i ]
229- info [' glyphs' ] += [self .arr2tensor (glyphs , len (prompt ))]
230- info [' gly_line' ] += [self .arr2tensor (gly_line , len (prompt ))]
231- info [' positions' ] += [self .arr2tensor (pos , len (prompt ))]
236+ info [" glyphs" ] += [self .arr2tensor (glyphs , len (prompt ))]
237+ info [" gly_line" ] += [self .arr2tensor (gly_line , len (prompt ))]
238+ info [" positions" ] += [self .arr2tensor (pos , len (prompt ))]
232239 # get masked_x
233- masked_img = ((edit_image .astype (np .float32 ) / 127.5 ) - 1.0 )* ( 1 - np_hint )
240+ masked_img = ((edit_image .astype (np .float32 ) / 127.5 ) - 1.0 ) * ( 1 - np_hint )
234241 masked_img = np .transpose (masked_img , (2 , 0 , 1 ))
235242 masked_img = torch .from_numpy (masked_img .copy ()).float ().to (self .device )
236243 if self .use_fp16 :
237244 masked_img = masked_img .half ()
238245 masked_x = self .encode_first_stage (masked_img [None , ...]).detach ()
239246 if self .use_fp16 :
240247 masked_x = masked_x .half ()
241- info [' masked_x' ] = torch .cat ([masked_x for _ in range (len (prompt ))], dim = 0 )
248+ info [" masked_x" ] = torch .cat ([masked_x for _ in range (len (prompt ))], dim = 0 )
242249 hint = self .arr2tensor (np_hint , len (prompt ))
243250
244- glyphs = torch .cat (info [' glyphs' ], dim = 1 ).sum (dim = 1 , keepdim = True )
245- positions = torch .cat (info [' positions' ], dim = 1 ).sum (dim = 1 , keepdim = True )
251+ glyphs = torch .cat (info [" glyphs" ], dim = 1 ).sum (dim = 1 , keepdim = True )
252+ positions = torch .cat (info [" positions" ], dim = 1 ).sum (dim = 1 , keepdim = True )
246253 enc_glyph = self .glyph_block (glyphs , emb , context )
247254 enc_pos = self .position_block (positions , emb , context )
248255 guided_hint = self .fuse_block (torch .cat ([enc_glyph , enc_pos , masked_x ], dim = 1 ))
0 commit comments