@@ -31,11 +31,11 @@


 class TextEmbeddingModule(nn.Module):
-    def __init__(self, font_path, device, use_fp16):
+    def __init__(self, use_fp16):
         super().__init__()
-        self.device = device
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
         # TODO: Learn if the recommended font file is free to use
-        self.font = ImageFont.truetype(font_path, 60)
+        self.font = ImageFont.truetype("./font/Arial_Unicode.ttf", 60)
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device)
         self.embedding_manager_config = {
             "valid": True,
@@ -49,12 +49,12 @@ def __init__(self, font_path, device, use_fp16):
         # TODO: Understand the reason of param.requires_grad = True
         for param in self.embedding_manager.embedding_parameters():
             param.requires_grad = True
-        rec_model_dir = "./ocr_weights/ppv3_rec.pth"
+        rec_model_dir = "./ocr/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir).eval()
         args = {}
         args["rec_image_shape"] = "3, 48, 320"
         args["rec_batch_num"] = 6
-        args["rec_char_dict_path"] = "./ocr_recog/ppocr_keys_v1.txt"
+        args["rec_char_dict_path"] = "./ocr/ppocr_keys_v1.txt"
         args["use_fp16"] = use_fp16
         self.cn_recognizer = TextRecognizer(args, self.text_predictor)
         for param in self.text_predictor.parameters():
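With the device and font path now resolved inside `__init__`, constructing the module takes a single flag. A minimal call-site sketch, inferred from the hunks above rather than taken from this commit:

    # before: the caller passed the font path and device explicitly
    # module = TextEmbeddingModule("./font/Arial_Unicode.ttf", "cuda", use_fp16=True)
    # after: the device is autodetected and the font path is hardcoded
    module = TextEmbeddingModule(use_fp16=True)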
@@ -65,185 +65,15 @@ def __init__(self, font_path, device, use_fp16):
     def forward(
         self,
         prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        hint,
         text_info,
         negative_prompt=None,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
-        lora_scale: Optional[float] = None,
-        clip_skip: Optional[int] = None,
     ):
-        # TODO: Convert `get_learned_conditioning` functions to `diffusers`' format
-        prompt_embeds = self.get_learned_conditioning(
-            {"c_concat": [hint], "c_crossattn": [[prompt] * num_images_per_prompt], "text_info": text_info}
-        )
-        negative_prompt_embeds = self.get_learned_conditioning(
-            {"c_concat": [hint], "c_crossattn": [[negative_prompt] * num_images_per_prompt], "text_info": text_info}
-        )
+        self.embedding_manager.encode_text(text_info)
+        prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)

-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
-            self._lora_scale = lora_scale
-
-            # dynamically adjust the LoRA scale
-            if not USE_PEFT_BACKEND:
-                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            else:
-                scale_lora_layers(self.text_encoder, lora_scale)
-
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            # textual inversion: process multi-vector tokens if necessary
-            if isinstance(self, TextualInversionLoaderMixin):
-                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                text_input_ids, untruncated_ids
-            ):
-                removed_text = self.tokenizer.batch_decode(
-                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
-                )
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-                )
-
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                attention_mask = text_inputs.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            if clip_skip is None:
-                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
-                prompt_embeds = prompt_embeds[0]
-            else:
-                prompt_embeds = self.text_encoder(
-                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
-                )
-                # Access the `hidden_states` first, that contains a tuple of
-                # all the hidden states from the encoder layers. Then index into
-                # the tuple to access the hidden states from the desired layer.
-                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
-                # We also need to apply the final LayerNorm here to not mess with the
-                # representations. The `last_hidden_states` that we typically use for
-                # obtaining the final prompt representations passes through the LayerNorm
-                # layer.
-                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
-
-        if self.text_encoder is not None:
-            prompt_embeds_dtype = self.text_encoder.dtype
-        elif self.unet is not None:
-            prompt_embeds_dtype = self.unet.dtype
-        else:
-            prompt_embeds_dtype = prompt_embeds.dtype
-
-        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens: List[str]
-            if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
-            elif prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
-
-            # textual inversion: process multi-vector tokens if necessary
-            if isinstance(self, TextualInversionLoaderMixin):
-                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
-
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                attention_mask = uncond_input.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            negative_prompt_embeds = negative_prompt_embeds[0]
-
-        if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = negative_prompt_embeds.shape[1]
-
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-        if self.text_encoder is not None:
-            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
-                unscale_lora_layers(self.text_encoder, lora_scale)
+        self.embedding_manager.encode_text(text_info)
+        negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode([negative_prompt], embedding_manager=self.embedding_manager)

         return prompt_embeds, negative_prompt_embeds
-
-    def get_learned_conditioning(self, c):
-        if hasattr(self.frozen_CLIP_embedder_t3, "encode") and callable(self.frozen_CLIP_embedder_t3.encode):
-            if self.embedding_manager is not None and c["text_info"] is not None:
-                self.embedding_manager.encode_text(c["text_info"])
-            if isinstance(c, dict):
-                cond_txt = c["c_crossattn"][0]
-            else:
-                cond_txt = c
-            if self.embedding_manager is not None:
-                cond_txt = self.frozen_CLIP_embedder_t3.encode(cond_txt, embedding_manager=self.embedding_manager)
-            else:
-                cond_txt = self.frozen_CLIP_embedder_t3.encode(cond_txt)
-            if isinstance(c, dict):
-                c["c_crossattn"][0] = cond_txt
-            else:
-                c = cond_txt
-            if isinstance(c, DiagonalGaussianDistribution):
-                c = c.mode()
-        else:
-            c = self.frozen_CLIP_embedder_t3(c)
-
-        return c
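The slimmed-down `forward` no longer handles classifier-free-guidance batching, `num_images_per_prompt` duplication, or device placement; it returns one conditional and one unconditional embedding. A hypothetical call under the new signature (the prompts are illustrative, and `text_info` is assumed to come from the pipeline's upstream glyph/position preprocessing):

    text_embedding_module = TextEmbeddingModule(use_fp16=False)
    prompt_embeds, negative_prompt_embeds = text_embedding_module(
        prompt='photo of a storefront sign that reads "OPEN"',
        text_info=text_info,  # prepared upstream; assumed here
        negative_prompt="low-res, blurry",
    )
    # CFG concatenation and per-prompt duplication now fall to the caller,
    # matching the parameters removed from the old signature.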