| 31 | 31 | from ...models import AutoencoderKL, FluxTransformer2DModel | 
| 32 | 32 | from ...schedulers import FlowMatchEulerDiscreteScheduler | 
| 33 | 33 | from ...utils import ( | 
| 34 |  | -    USE_PEFT_BACKEND, | 
| 35 |  | -    deprecate, | 
| 36 | 34 |     is_torch_xla_available, | 
| 37 | 35 |     logging, | 
| 38 | 36 |     replace_example_docstring, | 
| 39 |  | -    scale_lora_layers, | 
| 40 |  | -    unscale_lora_layers, | 
| 41 | 37 | ) | 
| 42 | 38 | from ...utils.torch_utils import randn_tensor | 
| 43 | 39 | from ..pipeline_utils import DiffusionPipeline | 
|  | 40 | +from .pipeline_flux_utils import FluxMixin | 
| 44 | 41 | from .pipeline_output import FluxPipelineOutput | 
| 45 | 42 | 
 | 
| 46 | 43 | 
 | 
| @@ -146,6 +143,7 @@ def retrieve_timesteps( | 
| 146 | 143 | 
 | 
| 147 | 144 | class FluxPipeline( | 
| 148 | 145 |     DiffusionPipeline, | 
|  | 146 | +    FluxMixin, | 
| 149 | 147 |     FluxLoraLoaderMixin, | 
| 150 | 148 |     FromSingleFileMixin, | 
| 151 | 149 |     TextualInversionLoaderMixin, | 
| @@ -215,178 +213,6 @@ def __init__( | 
| 215 | 213 |         ) | 
| 216 | 214 |         self.default_sample_size = 128 | 
| 217 | 215 | 
 | 
| 218 |  | -    def _get_t5_prompt_embeds( | 
| 219 |  | -        self, | 
| 220 |  | -        prompt: Union[str, List[str]] = None, | 
| 221 |  | -        num_images_per_prompt: int = 1, | 
| 222 |  | -        max_sequence_length: int = 512, | 
| 223 |  | -        device: Optional[torch.device] = None, | 
| 224 |  | -        dtype: Optional[torch.dtype] = None, | 
| 225 |  | -    ): | 
| 226 |  | -        device = device or self._execution_device | 
| 227 |  | -        dtype = dtype or self.text_encoder.dtype | 
| 228 |  | - | 
| 229 |  | -        prompt = [prompt] if isinstance(prompt, str) else prompt | 
| 230 |  | -        batch_size = len(prompt) | 
| 231 |  | - | 
| 232 |  | -        if isinstance(self, TextualInversionLoaderMixin): | 
| 233 |  | -            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) | 
| 234 |  | - | 
| 235 |  | -        text_inputs = self.tokenizer_2( | 
| 236 |  | -            prompt, | 
| 237 |  | -            padding="max_length", | 
| 238 |  | -            max_length=max_sequence_length, | 
| 239 |  | -            truncation=True, | 
| 240 |  | -            return_length=False, | 
| 241 |  | -            return_overflowing_tokens=False, | 
| 242 |  | -            return_tensors="pt", | 
| 243 |  | -        ) | 
| 244 |  | -        text_input_ids = text_inputs.input_ids | 
| 245 |  | -        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids | 
| 246 |  | - | 
| 247 |  | -        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): | 
| 248 |  | -            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) | 
| 249 |  | -            logger.warning( | 
| 250 |  | -                "The following part of your input was truncated because `max_sequence_length` is set to " | 
| 251 |  | -                f" {max_sequence_length} tokens: {removed_text}" | 
| 252 |  | -            ) | 
| 253 |  | - | 
| 254 |  | -        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] | 
| 255 |  | - | 
| 256 |  | -        dtype = self.text_encoder_2.dtype | 
| 257 |  | -        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) | 
| 258 |  | - | 
| 259 |  | -        _, seq_len, _ = prompt_embeds.shape | 
| 260 |  | - | 
| 261 |  | -        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method | 
| 262 |  | -        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) | 
| 263 |  | -        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) | 
| 264 |  | - | 
| 265 |  | -        return prompt_embeds | 
| 266 |  | - | 
| 267 |  | -    def _get_clip_prompt_embeds( | 
| 268 |  | -        self, | 
| 269 |  | -        prompt: Union[str, List[str]], | 
| 270 |  | -        num_images_per_prompt: int = 1, | 
| 271 |  | -        device: Optional[torch.device] = None, | 
| 272 |  | -    ): | 
| 273 |  | -        device = device or self._execution_device | 
| 274 |  | - | 
| 275 |  | -        prompt = [prompt] if isinstance(prompt, str) else prompt | 
| 276 |  | -        batch_size = len(prompt) | 
| 277 |  | - | 
| 278 |  | -        if isinstance(self, TextualInversionLoaderMixin): | 
| 279 |  | -            prompt = self.maybe_convert_prompt(prompt, self.tokenizer) | 
| 280 |  | - | 
| 281 |  | -        text_inputs = self.tokenizer( | 
| 282 |  | -            prompt, | 
| 283 |  | -            padding="max_length", | 
| 284 |  | -            max_length=self.tokenizer_max_length, | 
| 285 |  | -            truncation=True, | 
| 286 |  | -            return_overflowing_tokens=False, | 
| 287 |  | -            return_length=False, | 
| 288 |  | -            return_tensors="pt", | 
| 289 |  | -        ) | 
| 290 |  | - | 
| 291 |  | -        text_input_ids = text_inputs.input_ids | 
| 292 |  | -        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids | 
| 293 |  | -        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): | 
| 294 |  | -            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) | 
| 295 |  | -            logger.warning( | 
| 296 |  | -                "The following part of your input was truncated because CLIP can only handle sequences up to" | 
| 297 |  | -                f" {self.tokenizer_max_length} tokens: {removed_text}" | 
| 298 |  | -            ) | 
| 299 |  | -        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) | 
| 300 |  | - | 
| 301 |  | -        # Use pooled output of CLIPTextModel | 
| 302 |  | -        prompt_embeds = prompt_embeds.pooler_output | 
| 303 |  | -        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) | 
| 304 |  | - | 
| 305 |  | -        # duplicate text embeddings for each generation per prompt, using mps friendly method | 
| 306 |  | -        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) | 
| 307 |  | -        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) | 
| 308 |  | - | 
| 309 |  | -        return prompt_embeds | 
| 310 |  | - | 
| 311 |  | -    def encode_prompt( | 
| 312 |  | -        self, | 
| 313 |  | -        prompt: Union[str, List[str]], | 
| 314 |  | -        prompt_2: Optional[Union[str, List[str]]] = None, | 
| 315 |  | -        device: Optional[torch.device] = None, | 
| 316 |  | -        num_images_per_prompt: int = 1, | 
| 317 |  | -        prompt_embeds: Optional[torch.FloatTensor] = None, | 
| 318 |  | -        pooled_prompt_embeds: Optional[torch.FloatTensor] = None, | 
| 319 |  | -        max_sequence_length: int = 512, | 
| 320 |  | -        lora_scale: Optional[float] = None, | 
| 321 |  | -    ): | 
| 322 |  | -        r""" | 
| 323 |  | -
 | 
| 324 |  | -        Args: | 
| 325 |  | -            prompt (`str` or `List[str]`, *optional*): | 
| 326 |  | -                prompt to be encoded | 
| 327 |  | -            prompt_2 (`str` or `List[str]`, *optional*): | 
| 328 |  | -                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is | 
| 329 |  | -                used in all text-encoders | 
| 330 |  | -            device: (`torch.device`): | 
| 331 |  | -                torch device | 
| 332 |  | -            num_images_per_prompt (`int`): | 
| 333 |  | -                number of images that should be generated per prompt | 
| 334 |  | -            prompt_embeds (`torch.FloatTensor`, *optional*): | 
| 335 |  | -                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not | 
| 336 |  | -                provided, text embeddings will be generated from `prompt` input argument. | 
| 337 |  | -            pooled_prompt_embeds (`torch.FloatTensor`, *optional*): | 
| 338 |  | -                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. | 
| 339 |  | -                If not provided, pooled text embeddings will be generated from `prompt` input argument. | 
| 340 |  | -            lora_scale (`float`, *optional*): | 
| 341 |  | -                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. | 
| 342 |  | -        """ | 
| 343 |  | -        device = device or self._execution_device | 
| 344 |  | - | 
| 345 |  | -        # set lora scale so that monkey patched LoRA | 
| 346 |  | -        # function of text encoder can correctly access it | 
| 347 |  | -        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): | 
| 348 |  | -            self._lora_scale = lora_scale | 
| 349 |  | - | 
| 350 |  | -            # dynamically adjust the LoRA scale | 
| 351 |  | -            if self.text_encoder is not None and USE_PEFT_BACKEND: | 
| 352 |  | -                scale_lora_layers(self.text_encoder, lora_scale) | 
| 353 |  | -            if self.text_encoder_2 is not None and USE_PEFT_BACKEND: | 
| 354 |  | -                scale_lora_layers(self.text_encoder_2, lora_scale) | 
| 355 |  | - | 
| 356 |  | -        prompt = [prompt] if isinstance(prompt, str) else prompt | 
| 357 |  | - | 
| 358 |  | -        if prompt_embeds is None: | 
| 359 |  | -            prompt_2 = prompt_2 or prompt | 
| 360 |  | -            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 | 
| 361 |  | - | 
| 362 |  | -            # We only use the pooled prompt output from the CLIPTextModel | 
| 363 |  | -            pooled_prompt_embeds = self._get_clip_prompt_embeds( | 
| 364 |  | -                prompt=prompt, | 
| 365 |  | -                device=device, | 
| 366 |  | -                num_images_per_prompt=num_images_per_prompt, | 
| 367 |  | -            ) | 
| 368 |  | -            prompt_embeds = self._get_t5_prompt_embeds( | 
| 369 |  | -                prompt=prompt_2, | 
| 370 |  | -                num_images_per_prompt=num_images_per_prompt, | 
| 371 |  | -                max_sequence_length=max_sequence_length, | 
| 372 |  | -                device=device, | 
| 373 |  | -            ) | 
| 374 |  | - | 
| 375 |  | -        if self.text_encoder is not None: | 
| 376 |  | -            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: | 
| 377 |  | -                # Retrieve the original scale by scaling back the LoRA layers | 
| 378 |  | -                unscale_lora_layers(self.text_encoder, lora_scale) | 
| 379 |  | - | 
| 380 |  | -        if self.text_encoder_2 is not None: | 
| 381 |  | -            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: | 
| 382 |  | -                # Retrieve the original scale by scaling back the LoRA layers | 
| 383 |  | -                unscale_lora_layers(self.text_encoder_2, lora_scale) | 
| 384 |  | - | 
| 385 |  | -        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype | 
| 386 |  | -        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) | 
| 387 |  | - | 
| 388 |  | -        return prompt_embeds, pooled_prompt_embeds, text_ids | 
| 389 |  | - | 
| 390 | 216 |     def encode_image(self, image, device, num_images_per_prompt): | 
| 391 | 217 |         dtype = next(self.image_encoder.parameters()).dtype | 
| 392 | 218 | 
 | 
| @@ -503,97 +329,6 @@ def check_inputs( | 
| 503 | 329 |         if max_sequence_length is not None and max_sequence_length > 512: | 
| 504 | 330 |             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") | 
| 505 | 331 | 
 | 
| 506 |  | -    @staticmethod | 
| 507 |  | -    def _prepare_latent_image_ids(batch_size, height, width, device, dtype): | 
| 508 |  | -        latent_image_ids = torch.zeros(height, width, 3) | 
| 509 |  | -        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] | 
| 510 |  | -        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] | 
| 511 |  | - | 
| 512 |  | -        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape | 
| 513 |  | - | 
| 514 |  | -        latent_image_ids = latent_image_ids.reshape( | 
| 515 |  | -            latent_image_id_height * latent_image_id_width, latent_image_id_channels | 
| 516 |  | -        ) | 
| 517 |  | - | 
| 518 |  | -        return latent_image_ids.to(device=device, dtype=dtype) | 
| 519 |  | - | 
| 520 |  | -    @staticmethod | 
| 521 |  | -    def _pack_latents(latents, batch_size, num_channels_latents, height, width): | 
| 522 |  | -        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) | 
| 523 |  | -        latents = latents.permute(0, 2, 4, 1, 3, 5) | 
| 524 |  | -        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) | 
| 525 |  | - | 
| 526 |  | -        return latents | 
| 527 |  | - | 
| 528 |  | -    @staticmethod | 
| 529 |  | -    def _unpack_latents(latents, height, width, vae_scale_factor): | 
| 530 |  | -        batch_size, num_patches, channels = latents.shape | 
| 531 |  | - | 
| 532 |  | -        # VAE applies 8x compression on images but we must also account for packing which requires | 
| 533 |  | -        # latent height and width to be divisible by 2. | 
| 534 |  | -        height = 2 * (int(height) // (vae_scale_factor * 2)) | 
| 535 |  | -        width = 2 * (int(width) // (vae_scale_factor * 2)) | 
| 536 |  | - | 
| 537 |  | -        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) | 
| 538 |  | -        latents = latents.permute(0, 3, 1, 4, 2, 5) | 
| 539 |  | - | 
| 540 |  | -        latents = latents.reshape(batch_size, channels // (2 * 2), height, width) | 
| 541 |  | - | 
| 542 |  | -        return latents | 
| 543 |  | - | 
| 544 |  | -    def enable_vae_slicing(self): | 
| 545 |  | -        r""" | 
| 546 |  | -        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to | 
| 547 |  | -        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. | 
| 548 |  | -        """ | 
| 549 |  | -        depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." | 
| 550 |  | -        deprecate( | 
| 551 |  | -            "enable_vae_slicing", | 
| 552 |  | -            "0.40.0", | 
| 553 |  | -            depr_message, | 
| 554 |  | -        ) | 
| 555 |  | -        self.vae.enable_slicing() | 
| 556 |  | - | 
| 557 |  | -    def disable_vae_slicing(self): | 
| 558 |  | -        r""" | 
| 559 |  | -        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to | 
| 560 |  | -        computing decoding in one step. | 
| 561 |  | -        """ | 
| 562 |  | -        depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." | 
| 563 |  | -        deprecate( | 
| 564 |  | -            "disable_vae_slicing", | 
| 565 |  | -            "0.40.0", | 
| 566 |  | -            depr_message, | 
| 567 |  | -        ) | 
| 568 |  | -        self.vae.disable_slicing() | 
| 569 |  | - | 
| 570 |  | -    def enable_vae_tiling(self): | 
| 571 |  | -        r""" | 
| 572 |  | -        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to | 
| 573 |  | -        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow | 
| 574 |  | -        processing larger images. | 
| 575 |  | -        """ | 
| 576 |  | -        depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." | 
| 577 |  | -        deprecate( | 
| 578 |  | -            "enable_vae_tiling", | 
| 579 |  | -            "0.40.0", | 
| 580 |  | -            depr_message, | 
| 581 |  | -        ) | 
| 582 |  | -        self.vae.enable_tiling() | 
| 583 |  | - | 
| 584 |  | -    def disable_vae_tiling(self): | 
| 585 |  | -        r""" | 
| 586 |  | -        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to | 
| 587 |  | -        computing decoding in one step. | 
| 588 |  | -        """ | 
| 589 |  | -        depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." | 
| 590 |  | -        deprecate( | 
| 591 |  | -            "disable_vae_tiling", | 
| 592 |  | -            "0.40.0", | 
| 593 |  | -            depr_message, | 
| 594 |  | -        ) | 
| 595 |  | -        self.vae.disable_tiling() | 
| 596 |  | - | 
| 597 | 332 |     def prepare_latents( | 
| 598 | 333 |         self, | 
| 599 | 334 |         batch_size, | 
| @@ -629,26 +364,6 @@ def prepare_latents( | 
| 629 | 364 | 
 | 
| 630 | 365 |         return latents, latent_image_ids | 
| 631 | 366 | 
 | 
| 632 |  | -    @property | 
| 633 |  | -    def guidance_scale(self): | 
| 634 |  | -        return self._guidance_scale | 
| 635 |  | - | 
| 636 |  | -    @property | 
| 637 |  | -    def joint_attention_kwargs(self): | 
| 638 |  | -        return self._joint_attention_kwargs | 
| 639 |  | - | 
| 640 |  | -    @property | 
| 641 |  | -    def num_timesteps(self): | 
| 642 |  | -        return self._num_timesteps | 
| 643 |  | - | 
| 644 |  | -    @property | 
| 645 |  | -    def current_timestep(self): | 
| 646 |  | -        return self._current_timestep | 
| 647 |  | - | 
| 648 |  | -    @property | 
| 649 |  | -    def interrupt(self): | 
| 650 |  | -        return self._interrupt | 
| 651 |  | - | 
| 652 | 367 |     @torch.no_grad() | 
| 653 | 368 |     @replace_example_docstring(EXAMPLE_DOC_STRING) | 
| 654 | 369 |     def __call__( | 
|  | 
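For reference, the `_pack_latents` / `_unpack_latents` static methods deleted above implement the 2x2 patch packing that turns Flux VAE latents into the token sequence consumed by the transformer. Assuming the refactor simply moves them onto the new `FluxMixin` (the added `pipeline_flux_utils` import suggests this, though the mixin body is not shown in this diff), a minimal standalone sketch of the round-trip looks like the following; the free-function names `pack_latents` / `unpack_latents` are illustrative, not part of the library API.

```python
# Standalone sketch of the 2x2 latent packing mirrored from the
# _pack_latents / _unpack_latents static methods removed in this diff.
import torch


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B, (H/2)*(W/2), C*4): each 2x2 spatial patch becomes one token.
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)


def unpack_latents(latents: torch.Tensor, h: int, w: int) -> torch.Tensor:
    # Inverse of pack_latents, given the original latent height/width.
    b, _, packed_c = latents.shape
    c = packed_c // 4
    latents = latents.view(b, h // 2, w // 2, c, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(b, c, h, w)


x = torch.randn(1, 16, 64, 64)            # Flux-style latents: 16 channels, 64x64
packed = pack_latents(x)                   # -> (1, 1024, 64)
assert torch.equal(unpack_latents(packed, 64, 64), x)
```

If the mixin keeps the original signatures, callers that previously used `FluxPipeline._pack_latents(...)` or `pipe._unpack_latents(...)` should keep working unchanged, since the static methods would still resolve through `FluxMixin` in the class's MRO.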