From 42e11090f0b3effed1535445da78437796e53dad Mon Sep 17 00:00:00 2001 From: Marlon154 Date: Thu, 16 Jan 2025 11:40:52 +0100 Subject: [PATCH 1/4] add community pipeline for semantic guidance for flux --- .../pipeline_flux_semantic_guidance.py | 1328 +++++++++++++++++ 1 file changed, 1328 insertions(+) create mode 100644 examples/community/pipeline_flux_semantic_guidance.py diff --git a/examples/community/pipeline_flux_semantic_guidance.py b/examples/community/pipeline_flux_semantic_guidance.py new file mode 100644 index 000000000000..097831386887 --- /dev/null +++ b/examples/community/pipeline_flux_semantic_guidance.py @@ -0,0 +1,1328 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +import numpy as np +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, + CLIPVisionModelWithProjection, + CLIPImageProcessor +) +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.image_processor import VaeImageProcessor, PipelineImageInput +from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.transformers import FluxTransformer2DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + + >>> pipe = DiffusionPipeline.from_pretrained( + >>> "black-forest-labs/FLUX.1-dev", + >>> custom_pipeline="pipeline_flux_semantic_guidance", + >>> torch_dtype=torch.bfloat16 + >>> ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe( + >>> prompt=prompt, + >>> num_inference_steps=28, + >>> guidance_scale=3.5, + >>> editing_prompt=["cat", "dog"], # changes from cat to dog. + >>> reverse_editing_direction=[True, False], + >>> edit_warmup_steps=[6, 8], + >>> edit_guidance_scale=[6, 6.5], + >>> edit_threshold=[0.89, 0.89], + >>> edit_cooldown_steps = [25, 27], + >>> edit_momentum_scale=0.3, + >>> edit_mom_beta=0.6, + >>> generator=torch.Generator(device="cuda").manual_seed(6543), + >>> ).images[0] + >>> image.save("semantic_flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxSemanticGuidancePipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, + FluxIPAdapterMixin, +): + r""" + The Flux pipeline for text-to-image generation with semantic guidance. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def encode_text_with_editing( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + editing_prompt: Optional[List[str]] = None, + editing_prompt_2: Optional[List[str]] = None, + editing_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + """ + Encode text prompts with editing prompts and negative prompts for semantic guidance. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + prompt_2 (`str` or `List[str]`): + The prompt or prompts to guide image generation for second tokenizer. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + editing_prompt (`str` or `List[str]`, *optional*): + The editing prompts for semantic guidance. + editing_prompt_2 (`str` or `List[str]`, *optional*): + The editing prompts for semantic guidance for second tokenizer. + editing_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-computed embeddings for editing prompts. + pooled_editing_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-computed pooled embeddings for editing prompts. + device (`torch.device`, *optional*): + The device to use for computation. + num_images_per_prompt (`int`, defaults to 1): + Number of images to generate per prompt. + max_sequence_length (`int`, defaults to 512): + Maximum sequence length for text encoding. + lora_scale (`float`, *optional*): + Scale factor for LoRA layers if used. + + Returns: + tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, int]: + A tuple containing the prompt embeddings, pooled prompt embeddings, + text IDs, and number of enabled editing prompts. + """ + device = device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError("Prompt must be provided as string or list of strings") + + # Get base prompt embeddings + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # Handle editing prompts + if editing_prompt_embeds is not None: + enabled_editing_prompts = int(editing_prompt_embeds.shape[0]) + edit_text_ids = [] + elif editing_prompt is not None: + editing_prompt_embeds = [] + pooled_editing_prompt_embeds = [] + edit_text_ids = [] + + editing_prompt_2 = editing_prompt if editing_prompt_2 is None else editing_prompt_2 + for edit_1, edit_2 in zip(editing_prompt, editing_prompt_2): + e_prompt_embeds, pooled_embeds, e_ids = self.encode_prompt( + prompt=edit_1, + prompt_2=edit_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + editing_prompt_embeds.append(e_prompt_embeds) + pooled_editing_prompt_embeds.append(pooled_embeds) + edit_text_ids.append(e_ids) + + enabled_editing_prompts = len(editing_prompt) + + else: + edit_text_ids = [] + enabled_editing_prompts = 0 + + if enabled_editing_prompts: + for idx in range(enabled_editing_prompts): + editing_prompt_embeds[idx] = torch.cat([editing_prompt_embeds[idx]] * batch_size, dim=0) + pooled_editing_prompt_embeds[idx] = torch.cat([pooled_editing_prompt_embeds[idx]] * batch_size, dim=0) + + return (prompt_embeds, pooled_prompt_embeds, editing_prompt_embeds, + pooled_editing_prompt_embeds, text_ids, edit_text_ids, enabled_editing_prompts) + + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers + ): + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + + image_embeds.append(single_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + editing_prompt: Optional[Union[str, List[str]]] = None, + editing_prompt_2: Optional[Union[str, List[str]]] = None, + editing_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None, + reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, + edit_guidance_scale: Optional[Union[float, List[float]]] = 5, + edit_warmup_steps: Optional[Union[int, List[int]]] = 8, + edit_cooldown_steps: Optional[Union[int, List[int]]] = None, + edit_threshold: Optional[Union[float, List[float]]] = 0.9, + edit_momentum_scale: Optional[float] = 0.1, + edit_mom_beta: Optional[float] = 0.4, + edit_weights: Optional[List[float]] = None, + sem_guidance: Optional[List[torch.Tensor]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + editing_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image editing. If not defined, no editing will be performed. + editing_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image editing. If not defined, will use editing_prompt instead. + editing_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings for editing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, text embeddings will be generated from `editing_prompt` input argument. + reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): + Whether to reverse the editing direction for each editing prompt. + edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): + Guidance scale for the editing process. If provided as a list, each value corresponds to an editing prompt. + edit_warmup_steps (`int` or `List[int]`, *optional*, defaults to 10): + Number of warmup steps for editing guidance. If provided as a list, each value corresponds to an editing prompt. + edit_cooldown_steps (`int` or `List[int]`, *optional*, defaults to None): + Number of cooldown steps for editing guidance. If provided as a list, each value corresponds to an editing prompt. + edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): + Threshold for editing guidance. If provided as a list, each value corresponds to an editing prompt. + edit_momentum_scale (`float`, *optional*, defaults to 0.1): + Scale of momentum to be added to the editing guidance at each diffusion step. + edit_mom_beta (`float`, *optional*, defaults to 0.4): + Beta value for momentum calculation in editing guidance. + edit_weights (`List[float]`, *optional*): + Weights for each editing prompt. + sem_guidance (`List[torch.Tensor]`, *optional*): + Pre-generated semantic guidance. If provided, it will be used instead of calculating guidance from editing prompts. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if editing_prompt: + enable_edit_guidance = True + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + enabled_editing_prompts = len(editing_prompt) + elif editing_prompt_embeds is not None: + enable_edit_guidance = True + enabled_editing_prompts = editing_prompt_embeds.shape[0] + else: + enabled_editing_prompts = 0 + enable_edit_guidance = False + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + editing_prompts_embeds, + pooled_editing_prompt_embeds, + text_ids, + edit_text_ids, + enabled_editing_prompts, + ) = self.encode_text_with_editing( + prompt=prompt, + prompt_2=prompt_2, + pooled_prompt_embeds=pooled_prompt_embeds, + editing_prompt=editing_prompt, + editing_prompt_2=editing_prompt_2, + pooled_editing_prompt_embeds=pooled_editing_prompt_embeds, + lora_scale=lora_scale, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + negative_prompt_embeds = torch.cat([negative_prompt_embeds] * batch_size, dim=0) + negative_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds] * batch_size, dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + edit_momentum = None + if edit_warmup_steps: + tmp_e_warmup_steps = edit_warmup_steps if isinstance(edit_warmup_steps, list) else [edit_warmup_steps] + min_edit_warmup_steps = min(tmp_e_warmup_steps) + else: + min_edit_warmup_steps = 0 + + if edit_cooldown_steps: + tmp_e_cooldown_steps = edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps] + max_edit_cooldown_steps = min(max(tmp_e_cooldown_steps), num_inference_steps) + else: + max_edit_cooldown_steps = num_inference_steps + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.tensor([guidance_scale], device=device) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps: + noise_pred_edit_concepts = [] + for e_embed, pooled_e_embed, e_text_id in zip(editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids): + noise_pred_edit = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_e_embed, + encoder_hidden_states=e_embed, + txt_ids=e_text_id, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred_edit_concepts.append(noise_pred_edit) + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_guidance = true_cfg_scale * (noise_pred - noise_pred_uncond) + else: + noise_pred_uncond = noise_pred + noise_guidance = noise_pred + + if edit_momentum is None: + edit_momentum = torch.zeros_like(noise_guidance) + + if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps: + concept_weights = torch.zeros( + (enabled_editing_prompts, noise_guidance.shape[0]), + device=device, + dtype=noise_guidance.dtype, + ) + noise_guidance_edit = torch.zeros( + (enabled_editing_prompts, *noise_guidance.shape), + device=device, + dtype=noise_guidance.dtype, + ) + # noise_guidance_edit = torch.zeros_like(noise_guidance) + warmup_inds = [] + for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): + + if isinstance(edit_guidance_scale, list): + edit_guidance_scale_c = edit_guidance_scale[c] + else: + edit_guidance_scale_c = edit_guidance_scale + + if isinstance(edit_threshold, list): + edit_threshold_c = edit_threshold[c] + else: + edit_threshold_c = edit_threshold + if isinstance(reverse_editing_direction, list): + reverse_editing_direction_c = reverse_editing_direction[c] + else: + reverse_editing_direction_c = reverse_editing_direction + if edit_weights: + edit_weight_c = edit_weights[c] + else: + edit_weight_c = 1.0 + if isinstance(edit_warmup_steps, list): + edit_warmup_steps_c = edit_warmup_steps[c] + else: + edit_warmup_steps_c = edit_warmup_steps + + if isinstance(edit_cooldown_steps, list): + edit_cooldown_steps_c = edit_cooldown_steps[c] + elif edit_cooldown_steps is None: + edit_cooldown_steps_c = i + 1 + else: + edit_cooldown_steps_c = edit_cooldown_steps + if i >= edit_warmup_steps_c: + warmup_inds.append(c) + if i >= edit_cooldown_steps_c: + noise_guidance_edit[c, :, :, :] = torch.zeros_like(noise_pred_edit_concept) + continue + + if do_true_cfg: + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond + else: # simple sega + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred + tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2)) + + tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) + if reverse_editing_direction_c: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 + concept_weights[c, :] = tmp_weights + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp.dtype == torch.float32: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp.dtype) + + noise_guidance_edit_tmp = torch.where( + torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None], + noise_guidance_edit_tmp, + torch.zeros_like(noise_guidance_edit_tmp), + ) + + noise_guidance_edit[c, :, :, :] = noise_guidance_edit_tmp + # noise_guidance_edit[c] = noise_guidance_edit_tmp + + # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp + + warmup_inds = torch.tensor(warmup_inds).to(device) + if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: + concept_weights = concept_weights.to("cpu") # Offload to cpu + noise_guidance_edit = noise_guidance_edit.to("cpu") + + concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds) + concept_weights_tmp = torch.where( + concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp + ) + concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) + # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) + + noise_guidance_edit_tmp = torch.index_select( + noise_guidance_edit.to(device), 0, warmup_inds + ) + noise_guidance_edit_tmp = torch.einsum( + "cb,cbij->bij", concept_weights_tmp, noise_guidance_edit_tmp + ) + noise_guidance_edit_tmp = noise_guidance_edit_tmp + noise_guidance = noise_guidance + noise_guidance_edit_tmp + + del noise_guidance_edit_tmp + del concept_weights_tmp + concept_weights = concept_weights.to(device) + noise_guidance_edit = noise_guidance_edit.to(device) + + concept_weights = torch.where( + concept_weights < 0, torch.zeros_like(concept_weights), concept_weights + ) + + concept_weights = torch.nan_to_num(concept_weights) + + noise_guidance_edit = torch.einsum("cb,cbij->bij", concept_weights, noise_guidance_edit) + + noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum + + edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit + + if warmup_inds.shape[0] == len(noise_pred_edit_concepts): + noise_guidance = noise_guidance + noise_guidance_edit + + if sem_guidance is not None: + edit_guidance = sem_guidance[i].to(device) + noise_guidance = noise_guidance + edit_guidance + + if do_true_cfg: + noise_pred = noise_guidance + noise_pred_uncond + else: + noise_pred = noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(image, ) From a5892d75ec54b820a9a45650083fa68a6da3ae15 Mon Sep 17 00:00:00 2001 From: Marlon154 Date: Mon, 20 Jan 2025 13:16:54 +0100 Subject: [PATCH 2/4] fix imports in community pipeline for semantic guidance for flux --- .../pipeline_flux_semantic_guidance.py | 182 ++++++++++-------- 1 file changed, 97 insertions(+), 85 deletions(-) diff --git a/examples/community/pipeline_flux_semantic_guidance.py b/examples/community/pipeline_flux_semantic_guidance.py index 097831386887..e714bb8baa7e 100644 --- a/examples/community/pipeline_flux_semantic_guidance.py +++ b/examples/community/pipeline_flux_semantic_guidance.py @@ -13,22 +13,25 @@ # limitations under the License. import inspect -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Union + import numpy as np import torch from transformers import ( + CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, + CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast, - CLIPVisionModelWithProjection, - CLIPImageProcessor ) -from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput -from diffusers.image_processor import VaeImageProcessor, PipelineImageInput -from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.transformers import FluxTransformer2DModel +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( USE_PEFT_BACKEND, @@ -39,7 +42,7 @@ unscale_lora_layers, ) from diffusers.utils.torch_utils import randn_tensor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline + if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -57,7 +60,7 @@ >>> from diffusers import DiffusionPipeline >>> pipe = DiffusionPipeline.from_pretrained( - >>> "black-forest-labs/FLUX.1-dev", + >>> "black-forest-labs/FLUX.1-dev", >>> custom_pipeline="pipeline_flux_semantic_guidance", >>> torch_dtype=torch.bfloat16 >>> ) @@ -319,7 +322,6 @@ def _get_clip_prompt_embeds( return prompt_embeds - def encode_prompt( self, prompt: Union[str, List[str]], @@ -400,18 +402,18 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids def encode_text_with_editing( - self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - editing_prompt: Optional[List[str]] = None, - editing_prompt_2: Optional[List[str]] = None, - editing_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + editing_prompt: Optional[List[str]] = None, + editing_prompt_2: Optional[List[str]] = None, + editing_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, ): """ Encode text prompts with editing prompts and negative prompts for semantic guidance. @@ -500,8 +502,15 @@ def encode_text_with_editing( editing_prompt_embeds[idx] = torch.cat([editing_prompt_embeds[idx]] * batch_size, dim=0) pooled_editing_prompt_embeds[idx] = torch.cat([pooled_editing_prompt_embeds[idx]] * batch_size, dim=0) - return (prompt_embeds, pooled_prompt_embeds, editing_prompt_embeds, - pooled_editing_prompt_embeds, text_ids, edit_text_ids, enabled_editing_prompts) + return ( + prompt_embeds, + pooled_prompt_embeds, + editing_prompt_embeds, + pooled_editing_prompt_embeds, + text_ids, + edit_text_ids, + enabled_editing_prompts, + ) def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype @@ -546,19 +555,19 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds def check_inputs( - self, - prompt, - prompt_2, - height, - width, - negative_prompt=None, - negative_prompt_2=None, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, ): if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: logger.warning( @@ -566,7 +575,7 @@ def check_inputs( ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -743,47 +752,47 @@ def interrupt(self): @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - negative_prompt: Union[str, List[str]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - true_cfg_scale: float = 1.0, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 28, - sigmas: Optional[List[float]] = None, - guidance_scale: float = 3.5, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, - ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, - negative_ip_adapter_image: Optional[PipelineImageInput] = None, - negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, - editing_prompt: Optional[Union[str, List[str]]] = None, - editing_prompt_2: Optional[Union[str, List[str]]] = None, - editing_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None, - reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, - edit_guidance_scale: Optional[Union[float, List[float]]] = 5, - edit_warmup_steps: Optional[Union[int, List[int]]] = 8, - edit_cooldown_steps: Optional[Union[int, List[int]]] = None, - edit_threshold: Optional[Union[float, List[float]]] = 0.9, - edit_momentum_scale: Optional[float] = 0.1, - edit_mom_beta: Optional[float] = 0.4, - edit_weights: Optional[List[float]] = None, - sem_guidance: Optional[List[torch.Tensor]] = None, + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + editing_prompt: Optional[Union[str, List[str]]] = None, + editing_prompt_2: Optional[Union[str, List[str]]] = None, + editing_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None, + reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, + edit_guidance_scale: Optional[Union[float, List[float]]] = 5, + edit_warmup_steps: Optional[Union[int, List[int]]] = 8, + edit_cooldown_steps: Optional[Union[int, List[int]]] = None, + edit_threshold: Optional[Union[float, List[float]]] = 0.9, + edit_momentum_scale: Optional[float] = 0.1, + edit_mom_beta: Optional[float] = 0.4, + edit_weights: Optional[List[float]] = None, + sem_guidance: Optional[List[torch.Tensor]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -1037,7 +1046,9 @@ def __call__( min_edit_warmup_steps = 0 if edit_cooldown_steps: - tmp_e_cooldown_steps = edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps] + tmp_e_cooldown_steps = ( + edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps] + ) max_edit_cooldown_steps = min(max(tmp_e_cooldown_steps), num_inference_steps) else: max_edit_cooldown_steps = num_inference_steps @@ -1110,7 +1121,9 @@ def __call__( if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps: noise_pred_edit_concepts = [] - for e_embed, pooled_e_embed, e_text_id in zip(editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids): + for e_embed, pooled_e_embed, e_text_id in zip( + editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids + ): noise_pred_edit = self.transformer( hidden_states=latents, timestep=timestep / 1000, @@ -1160,7 +1173,6 @@ def __call__( # noise_guidance_edit = torch.zeros_like(noise_guidance) warmup_inds = [] for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): - if isinstance(edit_guidance_scale, list): edit_guidance_scale_c = edit_guidance_scale[c] else: @@ -1247,9 +1259,7 @@ def __call__( concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) - noise_guidance_edit_tmp = torch.index_select( - noise_guidance_edit.to(device), 0, warmup_inds - ) + noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds) noise_guidance_edit_tmp = torch.einsum( "cb,cbij->bij", concept_weights_tmp, noise_guidance_edit_tmp ) @@ -1325,4 +1335,6 @@ def __call__( if not return_dict: return (image,) - return FluxPipelineOutput(image, ) + return FluxPipelineOutput( + image, + ) From 481c88a6d68f114dccc3ead0d3ff132958a066f0 Mon Sep 17 00:00:00 2001 From: Marlon May <77202149+Marlon154@users.noreply.github.com> Date: Thu, 23 Jan 2025 13:38:32 +0100 Subject: [PATCH 3/4] Update examples/community/pipeline_flux_semantic_guidance.py Co-authored-by: hlky --- examples/community/pipeline_flux_semantic_guidance.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/community/pipeline_flux_semantic_guidance.py b/examples/community/pipeline_flux_semantic_guidance.py index e714bb8baa7e..5124503c99b7 100644 --- a/examples/community/pipeline_flux_semantic_guidance.py +++ b/examples/community/pipeline_flux_semantic_guidance.py @@ -85,6 +85,7 @@ """ +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, From a79f9946de1c8b8c820ca24771ef2245230954cb Mon Sep 17 00:00:00 2001 From: Marlon154 Date: Thu, 23 Jan 2025 13:48:47 +0100 Subject: [PATCH 4/4] fix community pipeline for semantic guidance for flux --- .../pipeline_flux_semantic_guidance.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/examples/community/pipeline_flux_semantic_guidance.py b/examples/community/pipeline_flux_semantic_guidance.py index 5124503c99b7..3bb080510902 100644 --- a/examples/community/pipeline_flux_semantic_guidance.py +++ b/examples/community/pipeline_flux_semantic_guidance.py @@ -230,6 +230,7 @@ def __init__( ) self.default_sample_size = 128 + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -279,6 +280,7 @@ def _get_t5_prompt_embeds( return prompt_embeds + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], @@ -323,6 +325,7 @@ def _get_clip_prompt_embeds( return prompt_embeds + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], @@ -513,6 +516,7 @@ def encode_text_with_editing( enabled_editing_prompts, ) + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype @@ -524,6 +528,7 @@ def encode_image(self, image, device, num_images_per_prompt): image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt ): @@ -555,6 +560,7 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs def check_inputs( self, prompt, @@ -633,6 +639,7 @@ def check_inputs( raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = torch.zeros(height, width, 3) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] @@ -647,6 +654,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): return latent_image_ids.to(device=device, dtype=dtype) @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents def _pack_latents(latents, batch_size, num_channels_latents, height, width): latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) @@ -655,6 +663,7 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width): return latents @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape @@ -670,6 +679,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to @@ -677,6 +687,7 @@ def enable_vae_slicing(self): """ self.vae.enable_slicing() + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to @@ -684,6 +695,7 @@ def disable_vae_slicing(self): """ self.vae.disable_slicing() + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to @@ -692,6 +704,7 @@ def enable_vae_tiling(self): """ self.vae.enable_tiling() + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to @@ -699,6 +712,7 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents def prepare_latents( self, batch_size, @@ -1171,7 +1185,7 @@ def __call__( device=device, dtype=noise_guidance.dtype, ) - # noise_guidance_edit = torch.zeros_like(noise_guidance) + warmup_inds = [] for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): if isinstance(edit_guidance_scale, list): @@ -1244,9 +1258,6 @@ def __call__( ) noise_guidance_edit[c, :, :, :] = noise_guidance_edit_tmp - # noise_guidance_edit[c] = noise_guidance_edit_tmp - - # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp warmup_inds = torch.tensor(warmup_inds).to(device) if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: @@ -1258,7 +1269,6 @@ def __call__( concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp ) concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) - # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds) noise_guidance_edit_tmp = torch.einsum(