@@ -13,22 +13,25 @@
 # limitations under the License.
 
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+from typing import Any, Callable, Dict, List, Optional, Union
+
 import numpy as np
 import torch
 from transformers import (
+    CLIPImageProcessor,
     CLIPTextModel,
     CLIPTokenizer,
+    CLIPVisionModelWithProjection,
     T5EncoderModel,
     T5TokenizerFast,
-    CLIPVisionModelWithProjection,
-    CLIPImageProcessor
 )
-from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
-from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
-from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from diffusers.models.autoencoders import AutoencoderKL
 from diffusers.models.transformers import FluxTransformer2DModel
+from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import (
     USE_PEFT_BACKEND,
@@ -39,7 +42,7 @@
     unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -57,7 +60,7 @@
         >>> from diffusers import DiffusionPipeline
 
         >>> pipe = DiffusionPipeline.from_pretrained(
-        >>>     "black-forest-labs/FLUX.1-dev", 
+        >>>     "black-forest-labs/FLUX.1-dev",
         >>>     custom_pipeline="pipeline_flux_semantic_guidance",
         >>>     torch_dtype=torch.bfloat16
         >>> )
@@ -319,7 +322,6 @@ def _get_clip_prompt_embeds(
 
         return prompt_embeds
 
-
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -400,18 +402,18 @@ def encode_prompt(
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
     def encode_text_with_editing(
-            self,
-            prompt: Union[str, List[str]],
-            prompt_2: Union[str, List[str]],
-            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-            editing_prompt: Optional[List[str]] = None,
-            editing_prompt_2: Optional[List[str]] = None,
-            editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-            pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-            device: Optional[torch.device] = None,
-            num_images_per_prompt: int = 1,
-            max_sequence_length: int = 512,
-            lora_scale: Optional[float] = None,
+        self,
+        prompt: Union[str, List[str]],
+        prompt_2: Union[str, List[str]],
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        editing_prompt: Optional[List[str]] = None,
+        editing_prompt_2: Optional[List[str]] = None,
+        editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 512,
+        lora_scale: Optional[float] = None,
     ):
         """
         Encode text prompts with editing prompts and negative prompts for semantic guidance.
@@ -500,8 +502,15 @@ def encode_text_with_editing(
                 editing_prompt_embeds[idx] = torch.cat([editing_prompt_embeds[idx]] * batch_size, dim=0)
                 pooled_editing_prompt_embeds[idx] = torch.cat([pooled_editing_prompt_embeds[idx]] * batch_size, dim=0)
 
-        return (prompt_embeds, pooled_prompt_embeds, editing_prompt_embeds,
-                pooled_editing_prompt_embeds, text_ids, edit_text_ids, enabled_editing_prompts)
+        return (
+            prompt_embeds,
+            pooled_prompt_embeds,
+            editing_prompt_embeds,
+            pooled_editing_prompt_embeds,
+            text_ids,
+            edit_text_ids,
+            enabled_editing_prompts,
+        )
 
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
@@ -546,27 +555,27 @@ def prepare_ip_adapter_image_embeds(
         return ip_adapter_image_embeds
 
     def check_inputs(
-            self,
-            prompt,
-            prompt_2,
-            height,
-            width,
-            negative_prompt=None,
-            negative_prompt_2=None,
-            prompt_embeds=None,
-            negative_prompt_embeds=None,
-            pooled_prompt_embeds=None,
-            negative_pooled_prompt_embeds=None,
-            callback_on_step_end_tensor_inputs=None,
-            max_sequence_length=None,
+        self,
+        prompt,
+        prompt_2,
+        height,
+        width,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
     ):
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             logger.warning(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
             )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
-                k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -743,47 +752,47 @@ def interrupt(self):
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
-            self,
-            prompt: Union[str, List[str]] = None,
-            prompt_2: Optional[Union[str, List[str]]] = None,
-            negative_prompt: Union[str, List[str]] = None,
-            negative_prompt_2: Optional[Union[str, List[str]]] = None,
-            true_cfg_scale: float = 1.0,
-            height: Optional[int] = None,
-            width: Optional[int] = None,
-            num_inference_steps: int = 28,
-            sigmas: Optional[List[float]] = None,
-            guidance_scale: float = 3.5,
-            num_images_per_prompt: Optional[int] = 1,
-            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-            latents: Optional[torch.FloatTensor] = None,
-            prompt_embeds: Optional[torch.FloatTensor] = None,
-            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-            ip_adapter_image: Optional[PipelineImageInput] = None,
-            ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-            negative_ip_adapter_image: Optional[PipelineImageInput] = None,
-            negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-            output_type: Optional[str] = "pil",
-            return_dict: bool = True,
-            joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-            max_sequence_length: int = 512,
-            editing_prompt: Optional[Union[str, List[str]]] = None,
-            editing_prompt_2: Optional[Union[str, List[str]]] = None,
-            editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-            pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-            reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
-            edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
-            edit_warmup_steps: Optional[Union[int, List[int]]] = 8,
-            edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
-            edit_threshold: Optional[Union[float, List[float]]] = 0.9,
-            edit_momentum_scale: Optional[float] = 0.1,
-            edit_mom_beta: Optional[float] = 0.4,
-            edit_weights: Optional[List[float]] = None,
-            sem_guidance: Optional[List[torch.Tensor]] = None,
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg_scale: float = 1.0,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 28,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: float = 3.5,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+        editing_prompt: Optional[Union[str, List[str]]] = None,
+        editing_prompt_2: Optional[Union[str, List[str]]] = None,
+        editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
+        edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
+        edit_warmup_steps: Optional[Union[int, List[int]]] = 8,
+        edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
+        edit_threshold: Optional[Union[float, List[float]]] = 0.9,
+        edit_momentum_scale: Optional[float] = 0.1,
+        edit_mom_beta: Optional[float] = 0.4,
+        edit_weights: Optional[List[float]] = None,
+        sem_guidance: Optional[List[torch.Tensor]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1037,7 +1046,9 @@ def __call__(
             min_edit_warmup_steps = 0
 
         if edit_cooldown_steps:
-            tmp_e_cooldown_steps = edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps]
+            tmp_e_cooldown_steps = (
+                edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps]
+            )
             max_edit_cooldown_steps = min(max(tmp_e_cooldown_steps), num_inference_steps)
         else:
             max_edit_cooldown_steps = num_inference_steps
@@ -1110,7 +1121,9 @@ def __call__(
 
                 if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps:
                     noise_pred_edit_concepts = []
-                    for e_embed, pooled_e_embed, e_text_id in zip(editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids):
+                    for e_embed, pooled_e_embed, e_text_id in zip(
+                        editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids
+                    ):
                         noise_pred_edit = self.transformer(
                             hidden_states=latents,
                             timestep=timestep / 1000,
@@ -1160,7 +1173,6 @@ def __call__(
                     # noise_guidance_edit = torch.zeros_like(noise_guidance)
                     warmup_inds = []
                     for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
-
                         if isinstance(edit_guidance_scale, list):
                             edit_guidance_scale_c = edit_guidance_scale[c]
                         else:
@@ -1247,9 +1259,7 @@ def __call__(
                         concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
                         # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
 
-                        noise_guidance_edit_tmp = torch.index_select(
-                            noise_guidance_edit.to(device), 0, warmup_inds
-                        )
+                        noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
                         noise_guidance_edit_tmp = torch.einsum(
                             "cb,cbij->bij", concept_weights_tmp, noise_guidance_edit_tmp
                         )
@@ -1325,4 +1335,6 @@ def __call__(
         if not return_dict:
             return (image,)
 
-        return FluxPipelineOutput(image, )
+        return FluxPipelineOutput(
+            image,
+        )
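For reference, here is a minimal usage sketch of the semantic-guidance pipeline this diff reformats. The loading call and checkpoint id follow the EXAMPLE_DOC_STRING shown above, and the editing arguments mirror the `__call__` signature in this file; the concrete prompt and concept strings are illustrative assumptions, not part of the change.

```python
import torch
from diffusers import DiffusionPipeline

# Load the community pipeline as in the docstring example above.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    custom_pipeline="pipeline_flux_semantic_guidance",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Editing arguments taken from the __call__ signature in this diff;
# the prompt/concept values below are made-up examples.
out = pipe(
    prompt="a photo of a cat sitting on a sofa",
    editing_prompt=["sunglasses"],        # concept(s) used for semantic guidance
    reverse_editing_direction=[False],    # False = guide toward the concept
    edit_guidance_scale=[5.0],
    edit_warmup_steps=[8],
    edit_threshold=[0.9],
    num_inference_steps=28,
    guidance_scale=3.5,
)
out.images[0].save("flux_semantic_guidance.png")
```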