@@ -23,15 +23,19 @@
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor
+from ...loaders import SanaLoraLoaderMixin
 from ...models import AutoencoderDC, SanaTransformer2DModel
 from ...models.attention_processor import PAGCFGSanaLinearAttnProcessor2_0, PAGIdentitySanaLinearAttnProcessor2_0
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     BACKENDS_MAPPING,
+    USE_PEFT_BACKEND,
     is_bs4_available,
     is_ftfy_available,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -185,6 +189,7 @@ def encode_prompt( |
         clean_caption: bool = False,
         max_sequence_length: int = 300,
         complex_human_instruction: Optional[List[str]] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -218,6 +223,15 @@ def encode_prompt( |
         if device is None:
             device = self._execution_device
 
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
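(Aside: with the PEFT backend, `scale_lora_layers` multiplies only the low-rank LoRA update of each adapted layer by `lora_scale`, leaving the frozen base weights untouched, and `unscale_lora_layers` divides it back out. The snippet below is a conceptual sketch of that forward pass, not the PEFT implementation; the function and tensor names are invented for illustration.)

```python
import torch

def lora_linear_forward(x, base_weight, lora_A, lora_B, scale):
    # Frozen base projection plus the low-rank LoRA update, weighted by `scale`.
    # scale=1.0 applies the adapter at full strength; scale=0.0 disables it.
    return x @ base_weight.T + scale * (x @ lora_A.T @ lora_B.T)

x = torch.randn(2, 16)                                    # (batch, in_features)
base_weight = torch.randn(32, 16)                         # (out_features, in_features)
lora_A, lora_B = torch.randn(4, 16), torch.randn(32, 4)   # rank-4 adapter
out = lora_linear_forward(x, base_weight, lora_A, lora_B, scale=0.8)
```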
@@ -313,6 +327,11 @@ def encode_prompt( |
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
 
+        if self.text_encoder is not None:
+            if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
         return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
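With these changes in place, a caller can scale the text encoder's LoRA layers just for prompt encoding: the scale is applied via `scale_lora_layers` before the encoder runs and reverted via `unscale_lora_layers` afterwards. A minimal usage sketch, assuming `SanaPipeline` exposes the same `encode_prompt` signature; the repo ids are placeholders, not real checkpoints:

```python
import torch
from diffusers import SanaPipeline

# Placeholder ids: substitute a real Sana checkpoint and LoRA adapter.
pipe = SanaPipeline.from_pretrained("<sana-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.load_lora_weights("<sana-lora-adapter>")  # provided by SanaLoraLoaderMixin

# lora_scale affects the text encoder only for this call: the LoRA layers are
# scaled before encoding and unscaled again before the function returns.
prompt_embeds, prompt_mask, neg_embeds, neg_mask = pipe.encode_prompt(
    prompt="a corgi astronaut floating above the moon",
    lora_scale=0.7,
)
```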