 # text -> tokenizer ->


+from typing import List, Optional
+
 import torch
 from PIL import ImageFont
 from torch import nn

+from diffusers.loaders import (
+    StableDiffusionLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
-from diffusers.utils import logging
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)

 from .embedding_manager import EmbeddingManager
 from .frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
@@ -50,14 +62,167 @@ def __init__(self, font_path, device, use_fp16):
         self.embedding_manager.recog = self.cn_recognizer

     @torch.no_grad()
-    def forward(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, hint, n_prompt, text_info):
+    def forward(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        hint,
+        text_info,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
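+        # The keyword arguments added here mirror those of `encode_prompt` in diffusers' Stable Diffusion pipelines.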
+        # TODO: Convert `get_learned_conditioning` functions to `diffusers`' format
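+        # `get_learned_conditioning` (defined below) encodes the prompt together with the auxiliary
+        # `text_info` through FrozenCLIPEmbedderT3 and the embedding manager.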
         prompt_embeds = self.get_learned_conditioning(
-            {"c_concat": [hint], "c_crossattn": [[prompt] * len(prompt)], "text_info": text_info}
+            {"c_concat": [hint], "c_crossattn": [[prompt] * num_images_per_prompt], "text_info": text_info}
         )
         negative_prompt_embeds = self.get_learned_conditioning(
-            {"c_concat": [hint], "c_crossattn": [[n_prompt] * len(prompt)], "text_info": text_info}
+            {"c_concat": [hint], "c_crossattn": [[negative_prompt] * num_images_per_prompt], "text_info": text_info}
         )

+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
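+        # Infer the batch size from `prompt` when it is given, otherwise from the precomputed `prompt_embeds`.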
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
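+            # Tokenize the prompt, padding and truncating to the tokenizer's model max length.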
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
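+            # Compare against an untruncated tokenization and warn about any text CLIP had to drop.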
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
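+            # Only build an attention mask if the text encoder's config asks for one.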
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, which contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
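+        # Cast the embeddings to the text encoder's dtype (falling back to the UNet's) before moving them to `device`.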
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
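+            # Pad the unconditional tokens to the same sequence length as the conditional embeddings.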
+            max_length = prompt_embeds.shape[1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            negative_prompt_embeds = self.text_encoder(
+                uncond_input.input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            negative_prompt_embeds = negative_prompt_embeds[0]
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        if self.text_encoder is not None:
+            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
         return prompt_embeds, negative_prompt_embeds

     def get_learned_conditioning(self, c):
@@ -82,6 +247,3 @@ def get_learned_conditioning(self, c):
             c = self.frozen_CLIP_embedder_t3(c)

         return c
-
-    def get_unconditional_conditioning(self, N):
-        return self.get_learned_conditioning({"c_crossattn": [[""] * N], "text_info": None})